{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 安装类库\n",
    "# !mkdir /home/aistudio/external-libraries\n",
    "# !pip install imgaug -t /home/aistudio/external-libraries\n",
    "import sys\n",
    "sys.path.append('/home/aistudio/external-libraries')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "image shape: (32, 32, 3)\n",
      "label value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMUAAADDCAYAAAAyYdXtAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAa/UlEQVR4nO2daWxc53WG33Nn5SbulChREi15lR1bThzXtetUWZw4aQAnRRskaAMDdZYCCdqg+WO4QJsC/ZECTYKiCFKkqBsHSOOkcVK7jtPYdZ06dhNZsi1rtyTSEsVFJMVlOJzhcJb79ccMHc59z5gjjjwk7fMABDln7tz73Ts8c+d855z3E+ccDMP4Dd5aD8Aw1hvmFIYRwJzCMAKYUxhGAHMKwwhgTmEYAWpyChG5W0ReFZEzInL/5RqUYawlsto8hYiEAJwCcBeAYQAHAHzKOXe80mu6urpcf3//qo5nrAS/j7nFRbKl0mmyNbdsUvcYDodrH9Yq8BVboZBXt11czJAtFObP+my2fLuJC5NIzCZF22ctZ30rgDPOuUEAEJGHAdwDoKJT9Pf34+DBgzUc0qhIgR3gwtAA2fa/8BLZ7vzA3eouOzq7ah/XChQUW7rA1uT8tPr6wYETZGvvbCLb0NDpssd/9tkHKo6plq9P2wCcX/Z4uGQrQ0Q+JyIHReTg5ORkDYczjPpQi1Notx66hzvnvu2cu8U5d0t3d3cNhzOM+lDL16dhANuXPe4DMHqpO7Haq0vHV75fS26GbMmJQbI989iPebskfy8HgD/+zGfYqLxfvq+8h8rHrVM+R3PKa0fHhsg2PTusjnHs/DGyDZ6+SLbEXPn1Wcyk1P0Btd0pDgC4SkSuEJEogE8CeKyG/RnGumDVdwrnXF5Evgjg5wBCAB50zrHbGsYGo6Y5N+fcEwCeuExjMYx1gWW0DSPA2mRnVkBEzam8LdGmITxRZvcLSX7tAk+BN/lZsk2NXVCPPX5hnGwh4c/R1rZWskWiEbL5SqDtHKfqwvxS5AoL6hg7N3eSbXySA+2xgfI5oFwup+4PsDuFYRDmFIYRwJzCMAKYUxhGgHUZaNcLrRrT+VxYl5/hwG0hMc+vjXIh2qZtW/WDKwGrKEGn53P2em7sPNnOHv012V47cZL350WV/XEGGQB+8cQjZGvfup1st99xJ784zJW3U7MJsi3Oc5CfyUyQzeV5IgEAJqY5az8zy++X84PXu/Jkjt0pDCOAOYVhBDCnMIwA5hSGEeBtHWjD58zwxTMcnE68+BzZ0tMcNF7I8mfM1XfuUw991U23kM2L8Ntx5NgRsr38zDNkSyrB99wEZ6Qj4RjZMlN6xf8zPz1Htut+90Nk++33vJ/3uciZ85kJ3t/gAS6dGx/ljsHOnTvUMaZ9LgHPpfk6Rr2essfyBv/6dqcwjADmFIYRwJzCMAKYUxhGAHMKwwhQ0+yTiJwFkERRvifvnOMplXWMy3BJx9SrPPOB2TkydYQUcS6PZ1wGn31KPXbYcZlBfCvPsHz3R/9JtmMHD5FtVzuXmHR4PMYmZYarEFIaGAAMnuJZqedO/YhsvX3Xk+3OW68j2+TJ/yPbK0/+hGyLsyzCkBrZo46xcc+72NbAelUtV7SXPY7GKssJXI4p2fc657jYxDA2KPb1yTAC1OoUDsCTIvKiiHxO28AUAo2NRq1OcYdz7p0APgzgCyLynuAGphBobDRqlbgZLf2eEJGfoCi6/Owl7WQNNQq8KPcWNPdw/8Pk8Gtky0yyYl1TlPsh5jL6CZ78tVI60r6TbE8++Txvl+Teghavl23tcbKlFjn4PjmkCxdcSLFswvAUB8Hf+86/8naHesiWPs/i2k0FLtOINXApymKK1dIBYGczB9Xe5ivJlpHy9zqkqSMsvb7iMysgIk0i0rL0N4APAji62v0ZxnqhljvFZgA/KcnRhAH8m3Puvy7LqAxjDalFNnMQ
wE2XcSyGsS6wKVnDCLD2/RSaBF61wXclFf8qX++U5au2vINvfrn5WbINDL1KtvQ0TzlnYw3qsU+d4hV4Us2sghfO8UnOTfGqPgll9Z74Tg6+52Y4UD58Tg+0J7M8EdHSymqAQ2deIdv+aZb3v6qLg9tohM9vdpFtLT36dRwb5T6STY0dfJyOgJKgVF4Cwu4UhhHAnMIwAphTGEYAcwrDCLDmgbYW7yhV1RVeewnr5Sny/qKstxaJcRZ426138P6UhOjYS5x97lMU9QBg6iKLJhze/zLZGsIcfHe1cAC8704e42/dxOXW//jNb5ItucAl74B+LTSlvrSSbY5tZ4l833HwPT7BZfnh9s1kkya9ROiVY1zqn3iRxSd6d+0qe5ya4+MuYXcKwwhgTmEYAcwpDCOAOYVhBKh7oB1ciFzzSl8JoDNZ7qeOKhlpQF+XzdPS3ErwnVfS5APT3G07owShi1ffQLbr33W7OsbcEGelf/jT/+btFri0+uN37yPb73/0g2Q7fYZl6idSHOBnXUgdY8TxttEwb9sS52vR1MaBcSLH59K0mbPuroFl/IcndSn+wgJPRGSVnvpnHisv4E7OcpXCEnanMIwA5hSGEcCcwjACmFMYRoAVA20ReRDARwFMOOduKNk6APwAQD+AswA+4ZzjmuQAvnNYzJVnT+NKn/RcmteTe/7AfrJtam5Wj3Pz9TeSraWhkWyFAvcrj0yyANgvnuMA+LUhXiduUckMx7b2q2PMJzm7O3GOpernk3wtdvdzljwMDopnExxwZn0OlPMFbfU/wE9zcOs5TuWH4vweTk3zv8P4BE9YNCjrBDa18qRKcxtvBwAtSuDfEObJku1dbWWPB87ryw8A1d0pvgPg7oDtfgBPO+euAvB06bFhvCVY0Smcc88CCM4f3gPgodLfDwH42GUel2GsGauNKTY758YAoPSb9UxKLBdDu2hiaMYG4E0PtJeLoXWZGJqxAVhtRntcRHqdc2Mi0guAVwNXEAEkEBjNzXMgeeDQS2QbGhshWyzKolkA0N3BAlnX9O8mW2JuimyHDrFI2djZ42S7MMRB48QMn8uhI6y0DQC39l1Ltl1b+ENjpoP7jVu7OAt8fpT7rMfGOJhMJTkAbmvW+59T8xxoz81wJn5XTx/ZmuP8r5VuUBTP8zzZUUjxGAueXuqdbecSdYR50qG1tfwcw6HK94PV3ikeA3Bv6e97ATy6yv0YxrpjRacQke8D+BWAa0RkWETuA/BVAHeJyGkAd5UeG8ZbghW/PjnnPlXhKV4n1jDeAlhG2zAC1LV03PlAYbE8CHp+/wu03YvHDpNt97UczI2e5wXeAeA/Hn+abB/9SI5sA2dZkGzgPCuMeyEujZ5WsrMjw2fJFi+8Wx3jO/r7yfanf/JpsmlZ6d1tLEg2OsoTEaeP8ARBcoqnxVs7lWAVQCGvlIQrye9t7S1kc8pSZ+Lzi0MeZ59DIaWkP8fvHwCkFaG6UJgz7AW/PKB30LP4gN0pDIMwpzCMAOYUhhHAnMIwAphTGEaAus4+FfwCkvPlM0b/8yz3KnRu5TKNxQz3H5wb1CXkRZnReOEwq/cdVWa5RLkkIe0yhbnmf9/795Ktp53LNAAgn+bZmRuuuYZsniKdP/xznl1ruMizMHe1cJ3mlqu51+Tg5Jg6xpMN3DvR38clJt1KSUcmwyUiWt+G7/OskrYeXSysl6JklZ6PqNI740X0kiANu1MYRgBzCsMIYE5hGAHMKQwjQF0DbfEEkabygKe1g8UHRkZYXv3wK7xE97kz3L8AAL19HJR1buFyCd/nWv6Zad5nRAnc+3cpQexWLndYWNTLE7IZDrQLivDBwlku30if5cA4keCAvEEpB3n3Di6X6Y3xuAFg0xT3Y4TbWUDAj/B1dAUOlkUJqgs5nkARLSZWBBeK++Teifwi7zPqBV9va94ZRtWYUxhGAHMKwwhgTmEYAVarEPgVAJ8FsFSc/4Bz7omV9pVKZ7D/5fIehoIi
Ax8K8bBeG+Q+h5ERPdBubmcBgEKhnWzJJK/VpgXaVyjBaU83B9rDw6fI1h7WJd8j1/NkQDjBsvLnDx0j27E5lrT/6XHeLuFzwNkW52zvB6+5RR3j7VFWIjw/fpZsoVYOqvON3BORUwJg5/PkgvP5/deCZwAoFJSMuFMy58FlG95gvcTVKgQCwDecc3tLPys6hGFsFFarEGgYb1lqiSm+KCKHReRBEeHvJiWWKwQm3mD1GMNYL6zWKb4FYDeAvQDGAHyt0obLFQJb29oqbWYY64ZVZbSdc+NLf4vIPwN4vJrXLWYX8NrZI+UDUGTTezq5dFyURvN4g57l/MD7PkS2a/fsIlthkZUIezoUGffeHWTr7uAs8K7tXPq9o3urOkZNoC4xylL8U3MsvjgIDi5bbuSS8PwCZ/Fnp1ns4dFzLHAAANf3cJn4FVq6+QJPECy0cqbZ5bncPp/nQNvPceBeqJCBTmd4YiTepKzV1xAc92XOaJekMpf4OACuwTCMDUo1U7LfB7APQJeIDAP4awD7RGQviu52FsDn38QxGkZdWa1C4L+8CWMxjHWBZbQNI0BdS8ejUR9b+8uDsvYuzrDmchx8fej3WGlvaooDPAAIx5UF1LO8z5tvvp5smRQHg6OK7P7e6/i1u/t3km32oi4hP3aBy7Knzw+TzbuS93nne/eRLeNxcDo3z9cnz5cGx149wkYAQ6+eIVtPiAPUTR5Pgjift/OEtxOlfN8pg8xXiIuzinJguKAoDObLr4VTst5L2J3CMAKYUxhGAHMKwwhgTmEYAeoaaCdTCTx74GdltrwSVO3o59LvvbfvIdu5AV0MzRMOWKfneX07v8AZ8WSCA7+pOQ6WX3iFM8MnBzjLPTKiB9pxpRT62hhL4ntNnBG/oJSYP3/gl2TLK7FkJMYl64l5fdXabISvTyLOAX04xNulwedXUPqpQ8GSbgBhxZZT1sYDAE/4cz2kLDifWSyfQPGViYDX91nxGcN4m2JOYRgBzCkMI4A5hWEEqGugHYuHsfvK8mAyp5QO92zRsrNcVp1M6Q2B4TCXN+cKvH5bIslBcE5JnXb0ceAfiXGgHYpz7/TOa/XPHb/A9pYwB+q/fI7X5Tt2mgXSWlq4V0U8RQ08yxn7qVn9OvqOX+8UFfWkooy+kOX+dxHONEejvD6dZltQVOcBIBzl/xXP42ubpyDfAm3DqBpzCsMIYE5hGAHMKQwjQDWdd9sBfBfAFgA+gG875/5BRDoA/ABAP4rdd59wznHEtYymhjhu2VvexzyvlDcfP/4K2aZnedfX7rlBPU5L8ybtTMgyMcnBVi7L2yVneQmpuRRngTs7tig2XehkPsOfR/EQB8vhRg6+Czm+ZlFh9fbGZlYI95RgfnbyvDrGtt5+srVH+V8mMc0icL7wBEosxgG0pwTf+TyXg2vtBADQpCzlVVBS+U3N5QrsnqcL6QHV3SnyAL7snLsOwG0AviAiewDcD+Bp59xVAJ4uPTaMDU81YmhjzrmXSn8nAZwAsA3APQAeKm32EICPvVmDNIx6ckkxhYj0A7gZwH4Am51zY0DRcQCwuCrKxdBmp3ke3zDWG1U7hYg0A3gEwJecc3rpp8JyMbS2Dv6OaxjrjaqcQkQiKDrE95xzPy6Zx5f0n0q/WbXLMDYg1cw+CYqSNiecc19f9tRjAO4F8NXS70dX2lfBzyMxXy4C4IFLMuYSPHtw8iTP9pwZ/F/1OH07WGHwxr27ybZD2a7B45krpzTCF5Q+kGiEexWEqxAAAI0LPPPV28hjvHkvz650tXKpxfPPPk+2xAxr92r9K5Mj+ueZa+L+jsLVPEYo10cTj4gpi8YvpLgcxC9w70Q0rn9+hxTlyOyCos4QrPKpXOVRVe3THQA+DeCIiBwq2R5A0Rl+KCL3ARgC8IdV7Msw1j3ViKE9B22Sv8j7L+9wDGPtsYy2YQQwpzCMAHXtp/AEaIyW+6HzOVC647Z3kW337uvINnjurHqciUkWLpid
UiTbIxzkjy9wQN/WxsF3SwuXS7iIUiIyx30XANDRxOvodfdw30ZyOwfvB371K7JNzbKKoa9cWw3hVhMAQEcHP9GxjUtRUspHa0QRFIhqSycIR7wLC1zG4jw9Ms4rCoPaaacD+3yja2N3CsMIYE5hGAHMKQwjgDmFYQSoa6ANcfBC5YGRF1Gk3ZXFyru2bCPbdTfo68llMhyo+Yo63djFMbJNJDhgnZgbJ9uWXg6KW1s5MPUr1O3P5/jzaCrzAtlGprnM7Ohxzl4vZnjc8XiFCDpAU6sexG7vUHonkkNk89r4OG0RrhbwwT0RqsiA4/dqPqlfx5CnBO/KgoKUdK+UeYPdKQyDMKcwjADmFIYRwJzCMALUNdDOZBdxarR8HbXWNs4Mx7IcXG6Kc4NSu5JVBoC4UmbsgZvme9q5NDoS5gzyXJKz3CHHkdrcLJdqj0/yEgAAkBhnxcMzXSzY0Nd6M9n+6BPvIduRA/xabZ2/tnYWUlhUSt4BwM1yNv7o8cNk6+9m0YTOJi5vzyuKjlNKmfimCGfNnSJwAADzCRaViDfy/0rjpvIxep5eaQDYncIwCHMKwwhgTmEYAcwpDCNALQqBXwHwWQBLUegDzrkn3mhfBb+A2fnyIDqTZ4n1mCJzn2tpJVtyvpLKG5cFNzZw8NXc2Eu2eJSDxu5WLh3PKSp9mrT/8BleRB4AwopM/uFxVuo7rySlr45yGX2Hcn229nDG31NKrTONehA7FeHe7W3gyY2GMB+7oUlRNkzzyeQKrAaYzfByAbmsvuZdWlGYjMX42O3t5eqNoXBlnY1qZp+WFAJfEpEWAC+KyFOl577hnPv7KvZhGBuGanq0xwAsiZ4lRWRJIdAw3pLUohAIAF8UkcMi8qCIqErCyxUCUwm+VRrGeqMWhcBvAdgNYC+Kd5Kvaa9brhDYpFS/GsZ6o6qMtqYQ6JwbX/b8PwN4fKX9RCNx9G2+ssyWV2TTPaX0d2GBs7MTs7o2rZaB3r6TZfLTijR8Jsn7bG5WMradSjY8wsJlu3bq68k1NnPQOTjAZdCxsCKn38vXrG0zTwbMz3O2N1TgIHb39VeSDQD8k1zCncvzuOMxRQ7f4zF2NvN2YWUB+5mLXAUgPvfTA0B6gb99hGO8rRcq/1fX1t97fduKz/zmxapC4JJkZomPAzi60r4MYyNQi0Lgp0RkL4oChGcBfP5NGaFh1JlaFALfMCdhGBsVy2gbRoC6lo47V0A2Xx7IxmJcttzUwKXDhTxnNNMJVqwGgKZGDt4KOQ6qp9O8jl5cWdNNUw73PQ5C01nOsPds0dbfAxobOejcskUpty7wcRZ9zuJ2dnBP9EKCt4tHeNIg1MjbAUB8koPqhgt8Pp7PwXsBPGHhhfi9bmji9zqd4kmVSFwXLys4nlTxhYPvhXx5tYGv9IG/Ps6KzxjG2xRzCsMIYE5hGAHMKQwjQF0D7YJfQCpdnuHN+yzElZxn8bGQcGBaXN+eaW1hezrN+4woy01JmIP0VIYD6OQol4lrGWQo5wcAzudZ7pCiWu77SsCqzJAX0lxuHw5xcJpKc1CczOp95NLK2XRp4qA8dZED45wSyObBx15c4OuYcxwoD4+NqGO8MMEVA91bOaB36fKJmoJSQr+E3SkMI4A5hWEEMKcwjADmFIYRwJzCMALUt8zD95BbKC8TSM1zA7m2uHg2y7MrUaXUAgBmXuPyj7kUz17c8I6ryZa4wDMxnvBlUtdMU2aUXhvQZ01iUZ5Na+vgWZPWdv7cam3jkhVkeZYqrpSSJOZZKCKd5tkjAHALishBhGfscuDSDz+niBSE+H3JhXn2KZ3jGaXBIRZ1AIBkgv8H2vq4nyLvlZ+je4PV5e1OYRgBzCkMI4A5hWEEqKYdNS4iL4jIKyJyTET+pmS/QkT2i8hpEfmBiChfdA1j41FNoL0I4H3OufmSgMFzIvIzAH+BohjawyLyTwDuQ1HhoyK5rI/R4fJSCF8JTqMRLi8YGeMAOJvVRQHC
ipx+WzsHgyNjSjmJx+PxwPtrVPoSNHXBcEyX9Tl55iTZtmZ4jOGLXBoRiXCQ39zIqnhNTazct7DAgXYoWqlXgYPg5ngfb+cpDSfKAvEzeb7e0sOlMdPz/F4n5/UxZhx/rve/kxUUb7h5Z9njQ0eeVPcHVHGncEWWin8ipR8H4H0AflSyPwTgYyvtyzA2AlXFFCISKokWTAB4CsAAgFnn3NKc3TAqqAYuF0NLz+tTf4axnqjKKZxzBefcXgB9AG4FwPcn6BO/y8XQGpst7DDWP5c0++ScmwXwCwC3AWgTeT2r1QdAl9c2jA1GNVL83QByzrlZEWkA8AEAfwfgGQB/AOBhAPcCeHSlfS0u5jAwUL6guyiy+S3NbJubYf9NJvWvY3uURef7d7Ki3/DoWT52C0viuhzfBBubOCiOKcF3/w5dia6jgzO+mQxnfGeVdecSM4qqYoeyTlyOe0M8j4+bSPHC9ACQLXCWfDbBQgGbUpw5jykBcMbj/cWivF0iqfSBpPTP79Zt/O0j3q0IVzSXTzA4pddkiWpmn3oBPCQiIRTvLD90zj0uIscBPCwifwvgZRRVBA1jw1ONGNphFJXGg/ZBFOMLw3hLYRltwwhgTmEYAcS5yiW0l/1gIpMAzgHoAqBHdxsPO5f1yUrnstM51609UVeneP2gIgedc7fU/cBvAnYu65NazsW+PhlGAHMKwwiwVk7x7TU67puBncv6ZNXnsiYxhWGsZ+zrk2EEMKcwjAB1dwoRuVtEXhWRMyJyf72PXwsi8qCITIjI0WW2DhF5qtSW+5SIcEXhOkREtovIMyJyotRm/Ocl+4Y7n8vdMl1XpygVFX4TwIcB7EFxhdU99RxDjXwHwN0B2/0AnnbOXQXg6dLjjUAewJedc9eh2ArwhdJ7sRHPZ6ll+iYAewHcLSK3oVjN/Y3Sucyg2DK9IvW+U9wK4IxzbtA5l0Wx7PyeOo9h1TjnngUQbAy/B8V2XGADteU658accy+V/k4COIFi9+SGO5/L3TJdb6fYBmC51FvFNtYNxGbn3BhQ/EcD0LPG47lkRKQfxUro/dig51NLy3SQejuF1nFjc8JriIg0A3gEwJecU+Q7Ngi1tEwHqbdTDAPYvuzxW6GNdVxEegGg9JvFcdcpJcmiRwB8zzn345J5w54PcHlapuvtFAcAXFWaFYgC+CSAx+o8hsvNYyi24wJVtuWuB0REUOyWPOGc+/qypzbc+YhIt4i0lf5eapk+gd+0TAOXci7Oubr+APgIgFMofuf7y3ofv8axfx/AGIAcine9+wB0ojhLc7r0u2Otx1nlufwOil8nDgM4VPr5yEY8HwA3otgSfRjAUQB/VbLvAvACgDMA/h1ArJr9WZmHYQSwjLZhBDCnMIwA5hSGEcCcwjACmFMYRgBzCsMIYE5hGAH+HzSx9wB11O8hAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import paddle\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Read one batch of CIFAR-100 training samples\n",
    "reader = paddle.batch(\n",
    "    paddle.dataset.cifar.train100(),\n",
    "    batch_size=8) # dataset reader yielding batches of (image, label) samples\n",
    "data = next(reader()) # first batch\n",
    "index = 0 # position of the sample to inspect within the batch\n",
    "\n",
    "# Decode the image\n",
    "image = np.array([x[0] for x in data]).astype(np.float32) # image data as float32 in [0,1]\n",
    "# Scale [0,1] -> [0,255]. Round before the uint8 cast: float32 x/255*255 can land\n",
    "# slightly below the integer and astype would truncate it to the next lower value.\n",
    "image = np.clip(np.rint(image * 255), 0, 255)\n",
    "image = image[index].reshape((3, 32, 32)).transpose((1, 2, 0)).astype(np.uint8) # CHW -> HWC, cast to uint8\n",
    "print('image shape:', image.shape)\n",
    "\n",
    "# Optional augmentation preview (disabled; the same pipeline is used by train_augment)\n",
    "# sometimes = lambda aug: iaa.Sometimes(0.5, aug) # apply the wrapped augmenter with probability 0.5\n",
    "# seq = iaa.Sequential([\n",
    "#     sometimes(iaa.CropAndPad(px=(-4, 4))),      # randomly crop/pad up to 4 pixels\n",
    "#     iaa.Fliplr(0.5)])                           # random horizontal flip\n",
    "# image = seq(image=image)\n",
    "\n",
    "# Decode the label\n",
    "label = np.array([x[1] for x in data]).astype(np.int64) # label data as int64\n",
    "# Display names; several differ from the official CIFAR-100 names (plurals, spaces,\n",
    "# 'maple' for 'maple_tree', ...) but they still sort into the same order.\n",
    "vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "         'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "         'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "         'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "         'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "         'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "         'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "         'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "         'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "         'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "         'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "         'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "         'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "         'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "         'baby', 'boy', 'girl', 'man', 'woman',\n",
    "         'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "         'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "         'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "         'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "         'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # CIFAR-100 fine label names\n",
    "vlist.sort() # alphabetical order, matching the order of the CIFAR-100 fine label indices\n",
    "print('label value:', vlist[label[index]])\n",
    "\n",
    "# Save and display the image\n",
    "out_dir = './work/out'\n",
    "os.makedirs(out_dir, exist_ok=True)             # Image.save does not create missing directories\n",
    "image = Image.fromarray(image)                  # ndarray -> PIL image\n",
    "image.save(os.path.join(out_dir, 'img.png'))    # save the decoded image\n",
    "plt.figure(figsize=(3, 3))                      # figure size in inches\n",
    "plt.imshow(image)                               # draw the image\n",
    "plt.show()                                      # render the figure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n",
      "valid_data: image shape (128, 3, 32, 32), label shape:(128, 1)\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "import numpy as np\n",
    "import imgaug as ia\n",
    "import imgaug.augmenters as iaa\n",
    "\n",
    "# Per-channel normalization statistics (RGB, for images scaled to [0,1]).\n",
    "# NOTE(review): these are the commonly used CIFAR-10 statistics, while this notebook\n",
    "# trains on CIFAR-100 (commonly quoted: mean (0.5071, 0.4865, 0.4409),\n",
    "# std (0.2673, 0.2564, 0.2762)). Left unchanged to keep results reproducible.\n",
    "CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, -1)) # channel means\n",
    "CIFAR_STDV = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, 1, -1)) # channel standard deviations\n",
    "\n",
    "# Training augmentation: random crop/pad (half of the time) + random horizontal flip.\n",
    "# Built once and reused instead of being rebuilt for every batch.\n",
    "TRAIN_SEQ = iaa.Sequential([\n",
    "    iaa.Sometimes(0.5, iaa.CropAndPad(px=(-4, 4))), # randomly crop/pad up to 4 pixels\n",
    "    iaa.Fliplr(0.5)])                               # flip horizontally with probability 0.5\n",
    "\n",
    "def _to_uint8_bhwc(images):\n",
    "    \"\"\"Convert a float BCHW batch in [0,1] to a uint8 BHWC batch in [0,255].\"\"\"\n",
    "    # Round before casting: float32 x/255*255 can land slightly below the integer\n",
    "    # and astype(np.uint8) would truncate it to the next lower value.\n",
    "    images = np.clip(np.rint(images * 255), 0, 255)\n",
    "    return images.transpose((0, 2, 3, 1)).astype(np.uint8)\n",
    "\n",
    "def _normalize_to_bchw(images):\n",
    "    \"\"\"Standardize a uint8 BHWC batch per channel and return it as float32 BCHW.\"\"\"\n",
    "    images = (images / 255.0 - CIFAR_MEAN) / CIFAR_STDV\n",
    "    return images.transpose((0, 3, 1, 2)).astype(np.float32)\n",
    "\n",
    "# Training data augmentation\n",
    "def train_augment(images):\n",
    "    \"\"\"\n",
    "    Randomly augment and normalize a training batch.\n",
    "    Input:\n",
    "        images - float array (B, 3, H, W) with values in [0,1]\n",
    "    Output:\n",
    "        images - float32 array (B, 3, H, W), standardized per channel\n",
    "    \"\"\"\n",
    "    images = _to_uint8_bhwc(images)\n",
    "    images = TRAIN_SEQ(images=images)\n",
    "    return _normalize_to_bchw(images)\n",
    "\n",
    "# Validation data preprocessing (no random augmentation)\n",
    "def valid_augment(images):\n",
    "    \"\"\"\n",
    "    Normalize a validation batch the same way as training, without augmentation.\n",
    "    Input:\n",
    "        images - float array (B, 3, H, W) with values in [0,1]\n",
    "    Output:\n",
    "        images - float32 array (B, 3, H, W), standardized per channel\n",
    "    \"\"\"\n",
    "    # The uint8 round-trip mirrors the training path so both see identical quantization.\n",
    "    return _normalize_to_bchw(_to_uint8_bhwc(images))\n",
    "\n",
    "# Read one training batch\n",
    "train_reader = paddle.batch(\n",
    "    paddle.reader.shuffle(paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "    batch_size=128) # shuffled training reader\n",
    "train_data = next(train_reader()) # first training batch\n",
    "\n",
    "train_image = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # training images\n",
    "train_image = train_augment(train_image)                                                       # augment + normalize\n",
    "train_label = np.array([x[1] for x in train_data]).reshape((-1, 1)).astype(np.int64)           # training labels\n",
    "print('train_data: image shape {}, label shape:{}'.format(train_image.shape, train_label.shape))\n",
    "\n",
    "# Read one validation batch\n",
    "valid_reader = paddle.batch(\n",
    "    paddle.dataset.cifar.test100(),\n",
    "    batch_size=128) # validation reader\n",
    "valid_data = next(valid_reader()) # first validation batch\n",
    "\n",
    "valid_image = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # validation images\n",
    "valid_image = valid_augment(valid_image)                                                       # normalize\n",
    "valid_label = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64)           # validation labels\n",
    "print('valid_data: image shape {}, label shape:{}'.format(valid_image.shape, valid_label.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型设计"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear, BatchNorm\n",
    "import math\n",
    "\n",
    "# 模组结构：输入维度，输出维度，滑动步长，基础长度, 队列长度\n",
    "group_arch = [(3, 128, 1, 2, 5), (512, 256, 2, 2, 5), (1024, 512, 2, 2, 5)]\n",
    "group_dim  = 2048 # 模组输出维度\n",
    "class_dim  = 100  # 类别数量维度\n",
    "\n",
    "# Convolution unit: conv (no bias) followed by batch norm with optional activation\n",
    "class ConvUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=3, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Build a conv + batch-norm unit. Output size: H/W = (H/W + 2*P - F)/S + 1,\n",
    "        with padding P = (F-1)//2, so odd kernels keep H/W unchanged at stride 1.\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - convolution kernel size\n",
    "            stride      - convolution stride\n",
    "            act         - activation applied by the batch-norm layer (None = linear)\n",
    "        \"\"\"\n",
    "        super(ConvUnit, self).__init__()\n",
    "\n",
    "        same_pad = (filter_size - 1) // 2\n",
    "\n",
    "        # Convolution: MSRA (He) normal init; no bias because batch norm follows\n",
    "        self.conv = Conv2D(\n",
    "            num_channels=in_dim, num_filters=out_dim,\n",
    "            filter_size=filter_size, stride=stride, padding=same_pad,\n",
    "            param_attr=fluid.initializer.MSRA(uniform=False),\n",
    "            bias_attr=False,\n",
    "            act=None)\n",
    "\n",
    "        # Batch normalization: scale starts at 1, shift at 0\n",
    "        self.norm = BatchNorm(\n",
    "            num_channels=out_dim,\n",
    "            param_attr=fluid.initializer.Constant(1.0),\n",
    "            bias_attr=fluid.initializer.Constant(0.0),\n",
    "            act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Convolve x, then batch-normalize (and activate) the result.\"\"\"\n",
    "        return self.norm(self.conv(x))\n",
    "\n",
    "# Projection unit (shortcut path): average pool -> 1x1 conv -> batch norm\n",
    "class ProjUnit(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, filter_size=1, stride=1, act=None):\n",
    "        \"\"\"\n",
    "        Build a projection unit. Downsampling is done by the average pool\n",
    "        (window filter_size, stride `stride`, no padding), so the 1x1 conv only\n",
    "        changes the channel count. Output size: H/W = (H/W + 2*P - F)/S + 1 with P = 0.\n",
    "        Args:\n",
    "            in_dim      - number of input channels\n",
    "            out_dim     - number of output channels\n",
    "            filter_size - pooling window size\n",
    "            stride      - pooling stride\n",
    "            act         - activation applied by the batch-norm layer (None = linear)\n",
    "        \"\"\"\n",
    "        super(ProjUnit, self).__init__()\n",
    "\n",
    "        # Average pooling (downsampling)\n",
    "        self.pool = Pool2D(pool_size=filter_size, pool_stride=stride,\n",
    "                           pool_padding=0, pool_type='avg')\n",
    "\n",
    "        # 1x1 projection conv: MSRA (He) normal init, no bias because batch norm follows\n",
    "        self.conv = Conv2D(num_channels=in_dim, num_filters=out_dim,\n",
    "                           filter_size=1, stride=1, padding=0,\n",
    "                           param_attr=fluid.initializer.MSRA(uniform=False),\n",
    "                           bias_attr=False, act=None)\n",
    "\n",
    "        # Batch normalization: scale starts at 1, shift at 0\n",
    "        self.norm = BatchNorm(num_channels=out_dim,\n",
    "                              param_attr=fluid.initializer.Constant(1.0),\n",
    "                              bias_attr=fluid.initializer.Constant(0.0),\n",
    "                              act=act)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Pool, project and batch-normalize x.\"\"\"\n",
    "        pooled = self.pool(x)\n",
    "        projected = self.conv(pooled)\n",
    "        return self.norm(projected)\n",
    "\n",
    "# Hierarchical-split block (HS-Block)\n",
    "class HSBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, splits=5, act=None):\n",
    "        \"\"\"\n",
    "        Build an HS-Block. A 1x1 conv produces `splits` groups of out_dim//splits channels.\n",
    "        Group 0 passes through; every later group is concatenated with the carried half of\n",
    "        the previous group's 3x3-conv output and convolved again. Group 0, the kept halves\n",
    "        and the last conv output are concatenated and fused by a 1x1 conv.\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - number of output channels\n",
    "            stride  - kept for interface compatibility; currently unused (H/W are preserved)\n",
    "            splits  - number of channel groups, must satisfy 2 <= splits <= out_dim\n",
    "            act     - activation used by every conv unit\n",
    "        Raises:\n",
    "            ValueError - if splits < 2 or out_dim < splits\n",
    "        \"\"\"\n",
    "        super(HSBlock, self).__init__()\n",
    "\n",
    "        if splits < 2:\n",
    "            raise ValueError('HSBlock needs splits >= 2, got {}'.format(splits))\n",
    "        if out_dim < splits:\n",
    "            raise ValueError('HSBlock needs out_dim >= splits, got {} < {}'.format(out_dim, splits))\n",
    "\n",
    "        # Channel bookkeeping\n",
    "        self.splits = splits\n",
    "        channel0 = out_dim // splits # channels per group\n",
    "        channel1 = channel0 * 2      # 3x3 conv output: kept half + carried half\n",
    "        channel2 = channel0 * splits # channels produced by the initial 1x1 conv\n",
    "\n",
    "        # Split features: 1x1 conv producing the groups\n",
    "        self.conv1 = ConvUnit(in_dim=in_dim, out_dim=channel2, filter_size=1, stride=1, act=act)\n",
    "\n",
    "        # Chain of 3x3 convs for groups 1..splits-1, registered as conv2..conv<splits>\n",
    "        # (same sublayer names as the original 5-split layout). The first conv sees only\n",
    "        # its group (channel0); later ones also see the carried half (channel1).\n",
    "        self.chain = []\n",
    "        for i in range(2, splits + 1):\n",
    "            conv = ConvUnit(in_dim=(channel0 if i == 2 else channel1), out_dim=channel1,\n",
    "                            filter_size=3, stride=1, act=act)\n",
    "            self.chain.append(self.add_sublayer('conv' + str(i), conv))\n",
    "\n",
    "        # Fusion 1x1 conv, registered as conv<splits+1>; its input is group 0 +\n",
    "        # (splits-2) kept halves + the full last output = channel2 + channel0 channels\n",
    "        self.fuse_name = 'conv' + str(splits + 1)\n",
    "        self.add_sublayer(self.fuse_name,\n",
    "            ConvUnit(in_dim=channel2 + channel0, out_dim=out_dim, filter_size=1, stride=1, act=act))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Extract features by hierarchical channel splitting.\n",
    "        Input:\n",
    "            x - input features with in_dim channels on axis 1\n",
    "        Output:\n",
    "            x - output features with out_dim channels on axis 1, same H/W\n",
    "        \"\"\"\n",
    "        # Split features into groups\n",
    "        x = self.conv1(x)\n",
    "        groups = fluid.layers.split(input=x, num_or_sections=self.splits, dim=1)\n",
    "\n",
    "        # Hierarchical chain: each conv output is split into a kept half and a half\n",
    "        # carried into the next group; the last conv output is kept whole\n",
    "        kept = [groups[0]]\n",
    "        carry = None\n",
    "        last = len(self.chain) - 1\n",
    "        for i, conv in enumerate(self.chain):\n",
    "            y = groups[i + 1]\n",
    "            if carry is not None:\n",
    "                y = fluid.layers.concat(input=[y, carry], axis=1)\n",
    "            y = conv(y)\n",
    "            if i < last:\n",
    "                y_kept, carry = fluid.layers.split(input=y, num_or_sections=2, dim=1)\n",
    "                kept.append(y_kept)\n",
    "            else:\n",
    "                kept.append(y)\n",
    "\n",
    "        # Merge features\n",
    "        x = fluid.layers.concat(input=kept, axis=1)\n",
    "        x = getattr(self, self.fuse_name)(x)\n",
    "\n",
    "        return x\n",
    "    \n",
    "# Basic (bottleneck) unit\n",
    "class SSRBasic(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, queues=1, is_pass=True):\n",
    "        \"\"\"\n",
    "        Build a bottleneck unit: 1x1 conv (with stride) -> 3x3 conv or HS-Block -> 1x1 conv\n",
    "        to out_dim*4 channels, added to an identity or projection shortcut, then ReLU.\n",
    "        Args:\n",
    "            in_dim  - number of input channels\n",
    "            out_dim - bottleneck width; the unit outputs out_dim*4 channels\n",
    "            stride  - stride of the first 1x1 conv and of the projection shortcut\n",
    "            queues  - 1 for a plain 3x3 conv, otherwise the HS-Block split count\n",
    "            is_pass - True for an identity shortcut, False for a projection shortcut\n",
    "        Raises:\n",
    "            ValueError - if is_pass is True but the identity cannot match the residual\n",
    "                         branch (in_dim != out_dim*4 or stride != 1)\n",
    "        \"\"\"\n",
    "        super(SSRBasic, self).__init__()\n",
    "\n",
    "        if is_pass and (in_dim != out_dim * 4 or stride != 1):\n",
    "            raise ValueError('identity shortcut needs in_dim == out_dim*4 and stride == 1, '\n",
    "                             'got in_dim={}, out_dim={}, stride={}'.format(in_dim, out_dim, stride))\n",
    "\n",
    "        # Identity-shortcut flag\n",
    "        self.is_pass = is_pass\n",
    "\n",
    "        # Projection shortcut, built only when it is used so identity units do not\n",
    "        # carry parameters that never take part in the forward pass\n",
    "        if not is_pass:\n",
    "            self.proj = ProjUnit(in_dim=in_dim, out_dim=out_dim*4, filter_size=stride, stride=stride, act=None)\n",
    "\n",
    "        # Residual branch\n",
    "        self.con1 = ConvUnit(in_dim=in_dim, out_dim=out_dim, filter_size=1, stride=stride, act='relu')\n",
    "\n",
    "        if queues == 1:\n",
    "            self.con2 = ConvUnit(in_dim=out_dim, out_dim=out_dim, filter_size=3, stride=1, act='relu')\n",
    "        else:\n",
    "            self.con2 = HSBlock(in_dim=out_dim, out_dim=out_dim, stride=1, splits=queues, act='relu')\n",
    "\n",
    "        self.con3 = ConvUnit(in_dim=out_dim, out_dim=out_dim*4, filter_size=1, stride=1, act=None)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Apply the unit to x.\n",
    "        Output:\n",
    "            x - ReLU(shortcut + residual branch)\n",
    "            y - the same tensor as x, returned as this unit's feature output\n",
    "        \"\"\"\n",
    "        # Shortcut path\n",
    "        shortcut = x if self.is_pass else self.proj(x)\n",
    "\n",
    "        # Residual path\n",
    "        residual = self.con1(x)        # reduce channels (and downsample when stride > 1)\n",
    "        residual = self.con2(residual) # extract features\n",
    "        residual = self.con3(residual) # expand channels to out_dim*4\n",
    "\n",
    "        # Add the two paths\n",
    "        x = fluid.layers.elementwise_add(x=shortcut, y=residual, act='relu')\n",
    "        y = x\n",
    "\n",
    "        return x, y\n",
    "    \n",
    "# Stage: a sequence of basic units\n",
    "class SSRBlock(fluid.dygraph.Layer):\n",
    "    def __init__(self, in_dim, out_dim, stride=1, basics=1, queues=1):\n",
    "        \"\"\"\n",
    "        Build a stage of `basics` SSRBasic units, registered as block_0, block_1, ...\n",
    "        Args:\n",
    "            in_dim  - number of input channels of the stage\n",
    "            out_dim - bottleneck width; each unit outputs out_dim*4 channels\n",
    "            stride  - stride of the first unit (later units use stride 1)\n",
    "            basics  - number of units in the stage\n",
    "            queues  - HS-Block split count passed to every unit (1 = plain 3x3 conv)\n",
    "        \"\"\"\n",
    "        super(SSRBlock, self).__init__()\n",
    "\n",
    "        self.block_list = [] # units in execution order\n",
    "        for idx in range(basics):\n",
    "            first = idx == 0\n",
    "            # Only the first unit changes channels/resolution and uses a projection\n",
    "            # shortcut; the rest map out_dim*4 -> out_dim*4 at stride 1 with identity.\n",
    "            unit = SSRBasic(\n",
    "                in_dim=in_dim if first else out_dim * 4,\n",
    "                out_dim=out_dim,\n",
    "                stride=stride if first else 1,\n",
    "                queues=queues,\n",
    "                is_pass=not first)\n",
    "            self.block_list.append(self.add_sublayer('block_' + str(idx), unit))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run x through every unit in order.\n",
    "        Returns:\n",
    "            x      - output of the last unit\n",
    "            y_list - per-unit feature outputs, in order\n",
    "        \"\"\"\n",
    "        y_list = []\n",
    "        for unit in self.block_list:\n",
    "            x, unit_out = unit(x)\n",
    "            y_list.append(unit_out)\n",
    "\n",
    "        return x, y_list\n",
    "\n",
    "# Backbone: one SSRBlock stage per entry of group_arch\n",
    "class SSRGroup(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Build the backbone from the module-level group_arch list, registering the\n",
    "        stages as group_0, group_1, ...\n",
    "        \"\"\"\n",
    "        super(SSRGroup, self).__init__()\n",
    "\n",
    "        self.group_list = [] # stages in execution order\n",
    "        for idx, arch in enumerate(group_arch):\n",
    "            in_dim, out_dim, stride, basics, queues = arch[:5]\n",
    "            stage = SSRBlock(in_dim=in_dim, out_dim=out_dim, stride=stride,\n",
    "                             basics=basics, queues=queues)\n",
    "            self.group_list.append(self.add_sublayer('group_' + str(idx), stage))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Run x through every stage in order.\n",
    "        Returns:\n",
    "            x      - output of the last stage\n",
    "            y_list - one list of per-unit features for each stage\n",
    "        \"\"\"\n",
    "        stage_features = []\n",
    "        for stage in self.group_list:\n",
    "            x, feats = stage(x)\n",
    "            stage_features.append(feats)\n",
    "\n",
    "        return x, stage_features\n",
    "        \n",
    "# Classification network\n",
    "class SSRNet(fluid.dygraph.Layer):\n",
    "    def __init__(self):\n",
    "        \"\"\"\n",
    "        Build the classifier: SSRGroup backbone -> global average pooling ->\n",
    "        fully connected layer with softmax over class_dim classes.\n",
    "        \"\"\"\n",
    "        super(SSRNet, self).__init__()\n",
    "\n",
    "        # Backbone, output: N x C x H x W\n",
    "        self.backbone = SSRGroup()\n",
    "\n",
    "        # Global average pooling, output: N x C x 1 x 1\n",
    "        self.pool = Pool2D(global_pooling=True, pool_type='avg')\n",
    "\n",
    "        # Classifier, output: N x class_dim probabilities.\n",
    "        # Weights are drawn from U(-bound, bound) with bound = 1/sqrt(group_dim).\n",
    "        bound = 1.0/(math.sqrt(group_dim)*1.0)\n",
    "        self.fc = Linear(\n",
    "            input_dim=group_dim,\n",
    "            output_dim=class_dim,\n",
    "            param_attr=fluid.initializer.Uniform(-bound, bound),\n",
    "            bias_attr=fluid.initializer.Constant(0.0),\n",
    "            act='softmax')\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        Classify the input images.\n",
    "        Input:\n",
    "            x - input images\n",
    "        Output:\n",
    "            x - softmax class probabilities, shape N x class_dim\n",
    "        \"\"\"\n",
    "        features, _ = self.backbone(x)\n",
    "\n",
    "        pooled = self.pool(features)\n",
    "        flat = fluid.layers.reshape(pooled, [pooled.shape[0], -1])\n",
    "        return self.fc(flat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tatol param: 18148924\n",
      "infer shape: [1, 100]\n"
     ]
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from paddle.fluid.dygraph.base import to_variable\n",
    "import numpy as np\n",
    "\n",
    "# Sanity check: build the network, run one forward pass and count parameters\n",
    "with fluid.dygraph.guard():\n",
    "    # Random input: 1 image, 3 channels, 32x32\n",
    "    x = np.random.randn(1, 3, 32, 32).astype(np.float32)\n",
    "    x = to_variable(x)\n",
    "    \n",
    "    # Forward pass\n",
    "    backbone = SSRNet() # build the network\n",
    "    \n",
    "    infer = backbone(x) # run the prediction\n",
    "    \n",
    "    # Total number of values over all tensors returned by parameters()\n",
    "    parameters = int(sum(np.prod(p.shape) for p in backbone.parameters()))\n",
    "    \n",
    "    print('total param:', parameters)\n",
    "    print('infer shape:', infer.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd8HNW58PHfs0Vadcm2XGVZLoCLXLCNC6bYplfDpYQeTIgTCDXJmwBJuLSElnspFxIgphcbYqoxPZiWgMG9G3dbcpNkS7a6dve8f5xRsayykiWtVnq+n8/C7szZmWd25GfOnHNmRowxKKWU6jxc4Q5AKaVU29LEr5RSnYwmfqWU6mQ08SulVCejiV8ppToZTfxKKdXJaOJXSqlORhO/Ukp1Mpr4lVKqk/GEa8XdunUzGRkZ4Vq9UkpFpEWLFuUaY1IPZxlhS/wZGRksXLgwXKtXSqmIJCJbD3cZITf1iIhbRJaIyPt1zLtaRHJEZKnzuvZwA1NKKdU6mlLjvxlYAyTWM/91Y8wNhx+SUkqp1hRSjV9E0oCzgJmtG45SSqnWFmqN/1Hgd0BCA2UuEJETgB+BW40x25saTEVFBVlZWZSWljb1q6oWn89HWloaXq833KEopdqZRhO/iJwN7DHGLBKRyfUUmwvMMsaUicgvgBeBqXUsawYwAyA9Pf2QhWRlZZGQkEBGRgYiEvpWqIMYY8jLyyMrK4v+/fuHOxylVDsTSlPPJOBcEdkCzAamisgrNQsYY/KMMWXOx5nAmLoWZIx5xhgz1hgzNjX10NFIpaWldO3aVZP+YRIRunbtqmdOSqk6NZr4jTG3G2PSjDEZwCXA58aYK2qWEZFeNT6ei+0EbhZN+i1Df0elVH2aPY5fRO4BFhpj3gNuEpFzAT+wF7i6ZcI7VGlFgPziCrrFR+Fx64XHSinVVE3KnMaYL4wxZzvv73SSfuVZwTBjzEhjzBRjzNrWCBagrCLAngOlVARa/lnBeXl5jBo1ilGjRtGzZ0/69OlT9bm8vDykZUyfPp1169aFvM6ZM2dyyy23NDdkpZRqsrBdudtcLpdtwgi2wkPiu3btytKlSwG46667iI+P57e//e1BZYwxGGNwueo+Zj7//PMtHpdSSrWkiGsrcUnrJf76bNiwgaFDh3L55ZczbNgwdu7cyYwZMxg7dizDhg3jnnvuqSp73HHHsXTpUvx+P8nJydx2222MHDmSiRMnsmfPngbXs3nzZqZMmcKIESM45ZRTyMrKAmD27NlkZmYycuRIpkyZAsCKFSs45phjGDVqFCNGjGDTpk2t9wMopTqUdlvjv3vuKlbv2H/I9KAxlJQH8HnduF1N68Ac2juR/z5nWLPiWbt2LS+99BJjx44F4IEHHqBLly74/X6mTJnChRdeyNChQw/6TkFBASeeeCIPPPAAv/71r3nuuee47bbb6l3H9ddfz7XXXsvll1/OM888wy233MKcOXO4++67+eKLL+jRowf5+fkA/O1vf+O3v/0tP/nJTygrK8O04YFQKRXZIq7GX5nq2zrNDRw4sCrpA8yaNYvRo0czevRo1qxZw+rVqw/5TkxMDGeccQYAY8aMYcuWLQ2uY8GCBVxyySUAXHXVVXz99dcATJo0iauuuoqZM2cSDAYBOPbYY7nvvvt46KGH2L59Oz6fryU2UynVCbTbGn99NfOKQJA1O/fTOzmGbvHRbRZPXFxc1fv169fz2GOP8f3335OcnMwVV1xR55j5qKioqvdutxu/39+sdf/jH/9gwYIFvP/++4wePZolS5Zw5ZVXMnHiRObNm8fpp5/Oc889xwknnNCs5SulOpeIq/FXtu6Es2Vj//79JCQkkJiYyM6dO/n4449bZLkTJkzgjTfeAOCVV16pSuSbNm1iwoQJ3HvvvaSkpJCdnc2mTZsYNGgQN998M2effTbLly9vkRiUUh1fu63x1y9cjT3V
Ro8ezdChQxk8eDD9+vVj0qRJLbLcJ598kmuuuYb777+fHj16VI0QuvXWW9m8eTPGGE499VQyMzO57777mDVrFl6vl969e3PXXXe1SAxKqY5PwtUpOHbsWFP7QSxr1qxhyJAhDX4vaAwrswvomeije6K2azcklN9TKRVZRGSRMWZs4yXrF3FNPeGv7yulVGSLvMTvjOPfvV9vQKaUUs0RcYlfKaXU4dHEr5RSnUxEJ369WlUppZouwhN/uCNQSqnIE9GJPzu/pEWXN2XKlEMuxnr00Ue57rrrGvxefHw8ADt27ODCCy+ss8zkyZOpPXy1oelKKdVaIjrx55dUtOjyLr30UmbPnn3QtNmzZ3PppZeG9P3evXszZ86cFo1JKaVaWsiJX0TcIrJERN6vY160iLwuIhtEZIGIZLRkkLUlUkx3yW/xwfwXXngh8+bNq3roypYtW9ixYwfHH388hYWFnHTSSYwePZrhw4fz7rvvHvL9LVu2kJmZCUBJSQmXXHIJQ4YM4fzzz6ekpPGzk1mzZjF8+HAyMzP5/e9/D0AgEODqq68mMzOT4cOH88gjjwDw+OOPM3ToUEaMGFF1YzellApFU27ZcDP2WbqJdcz7GbDPGDNIRC4BHgR+cliRfXgb7FpR56w+5SV4CBBvfBDdhE3oORzOeKDe2V26dGHcuHF8+OGHTJs2jdmzZ3PxxRcjIvh8Pt5++20SExPJzc1lwoQJnHvuufU+2/bvf/87sbGxrFmzhuXLlzN69OgGQ9uxYwe///3vWbRoESkpKZx66qm888479O3bl+zsbFauXAlQdVvmBx54gM2bNxMdHV01TSmlQhFSjV9E0oCzgJn1FJkGvOi8nwOcJK34tG9T478trWZzT81mHmMMd9xxByNGjODkk08mOzub3bt317ucr776iiuusM+kHzFiBCNGjGhwvT/88AOTJ08mNTUVj8fD5ZdfzldffcWAAQPYtGkTN954Ix999BGJiYlVy7z88st55ZVX8Hgi8JZLSqmwCTVjPAr8DkioZ34fYDuAMcYvIgVAVyC3ZiERmQHMAEhPT294jQ3UzAt2bCWVvWwO9md4WnJoWxCiadOmceutt7J48WKKi4sZM2YMAK+++io5OTksWrQIr9dLRkZGnbdibmkpKSksW7aMjz/+mKeeeoo33niD5557jnnz5vHVV18xd+5c/vznP7NixQo9ACilQtJojV9Ezgb2GGMWHe7KjDHPGGPGGmPGpqamNns5UV6b4FwEDzekQ8THxzNlyhSuueaagzp1CwoK6N69O16vl/nz57N169YGl3PCCSfw2muvAbBy5cpGb5s8btw4vvzyS3JzcwkEAsyaNYsTTzyR3NxcgsEgF1xwAffddx+LFy8mGAyyfft2pkyZwoMPPkhBQQGFhYWHv/FKqU4hlCriJOBcETkT8AGJIvKKMeaKGmWygb5Aloh4gCQgr8WjdRjnVm3Sis09559//kEjfC6//HLOOecchg8fztixYxk8eHCDy7juuuuYPn06Q4YMYciQIVVnDvXp1asXDzzwAFOmTMEYw1lnncW0adNYtmwZ06dPr3ry1v33308gEOCKK66goKAAYww33XQTyckte+ajlOq4mnRbZhGZDPzWGHN2rem/AoYbY37pdO7+lzHm4oaW1dzbMgMc2LubhNIdbCCdQb27hhx/Z6O3ZVaq4wnrbZlF5B4ROdf5+CzQVUQ2AL8G6n+ieAuIj7GPNIyPiujLEJRSKiya1BtojPkC+MJ5f2eN6aXARS0ZWENE3M6KW76NXymlOrp2V2UOqelJbNjSCp27HYXewE4pVZ92lfh9Ph95eXmNJ63KSwS0xl8nYwx5eXn4fPpoSqXUodrVwO+0tDSysrLIyclpuGCgAg7sYb+rjL15e9smuAjj8/lIS0sLdxhKqXaoXSV+r9dL//79Gy+Yvw0ePZ47gr/kL/c82PqBKaVUB9KumnpC5o0FwB0o1bZspZRqoghN/DEAxFJG
IKiJXymlmiIyE7/HJv4YKcOviV8ppZokMhO/y4Xf5cNHuSZ+pZRqoshM/IDf7bNNPQFN/Eop1RQRm/gDbh8xlOEP6lh+pZRqiohN/H53DDGiTT1KKdVUEZv4A+4Yp8aviV8ppZoiYhN/0GOberSNXymlmiZiE3/AE+s09Wgbv1JKNUXEJv7KGr829SilVNNEbOI3HqeNX5t6lFKqSUJ52LpPRL4XkWUiskpE7q6jzNUikiMiS53Xta0TbrWgx47q0Vs2KKVU04Ryd84yYKoxplBEvMA3IvKhMea7WuVeN8bc0PIh1s14Y4mlVNv4lVKqiRpN/Mbe/rLQ+eh1XmGvZhtvHLGU4Q9o4ldKqaYIqY1fRNwishTYA3xqjFlQR7ELRGS5iMwRkb71LGeGiCwUkYWNPmylMVFxuMRQUVp0eMtRSqlOJqTEb4wJGGNGAWnAOBHJrFVkLpBhjBkBfAq8WM9ynjHGjDXGjE1NTT2cuPHGJAJQWlRwWMtRSqnOpkmjeowx+cB84PRa0/OMMWXOx5nAmJYJr35RsQkAlBTtb+1VKaVUhxLKqJ5UEUl23scApwBra5XpVePjucCalgyyLr64JACKCzXxK6VUU4RS4+8FzBeR5cAP2Db+90XkHhE51ylzkzPUcxlwE3B164RbrTLxf75sY2uvSimlOpRQRvUsB46uY/qdNd7fDtzesqE1zO2LB6CHz9+Wq1VKqYgXsVfuEhVn/xcsCXMgSikVWSI/8QeKwxyIUkpFlghO/HZUT3RAa/xKKdUUEZz4nRq/0cSvlFJNEbmJ3xNNADfR2tSjlFJNErmJX4Rydww+rfErpVSTRG7iB8pdsURr4ldKqSaJ6MRf4Y4hxpSGOwyllIooEZ34y91xxGiNXymlmiSiE7/fHUMMpdhHBiillApFRCf+gCeWOEopLNPbNiilVKgiOvHHJSQTRylLtuWHOxSllIoYEZ34PTEJxEkppRWBcIeilFIRI6ITP9HxxFJKRUDb+JVSKlQRnfglKo44KaPCr238SikVqlCewOUTke9FZJnzsJW76ygTLSKvi8gGEVkgIhmtEewh6422N2rL3pPbFqtTSqkOIZQafxkw1RgzEhgFnC4iE2qV+RmwzxgzCHgEeLBlw6ybK9o+jOWFL1a1xeqUUqpDaDTxG6vQ+eh1XrUb1acBLzrv5wAniYi0WJT1cDuJP0706l2llApVSG38IuIWkaXAHuwzdxfUKtIH2A5gjPEDBUDXlgy0Lq4YJ/GjiV8ppUIVUuI3xgSMMaOANGCciGQ2Z2UiMkNEForIwpycnOYs4iBun23j18SvlFKha9KoHmNMPjAfOL3WrGygL4CIeIAkIK+O7z9jjBlrjBmbmpravIhr8Didu7Ha1KOUUiELZVRPqogkO+9jgFOAtbWKvQf81Hl/IfC5aYMb6Lh8tqknXmv8SikVMk8IZXoBL4qIG3ugeMMY876I3AMsNMa8BzwLvCwiG4C9wCWtFnFNUTbxa41fKaVC12jiN8YsB46uY/qdNd6XAhe1bGghcJ6729Vb0earVkqpSBXRV+5W1fj1nvxKKRWyyE78nij84sWnT+FSSqmQRXbip/Lxi1rjV0qpUEV+4nfF4tNRPUopFbLIT/yeGOIoIRjUWzMrpVQoIj/xu+OIo5SKYDDcoSilVESI+MS/zx9FrJTx4Ypd4Q5FKaUiQsQn/v2BKOIoYWNOYeOFlVJKRX7iPzK9F3GUktE1LtyhKKVURIj4xB8Vk0CslOHXNn6llApJxCd+iY4nnhLK9YHrSikVkshP/DHJ+KSCYFlxuENRSqmIEPGJ3xVnH/TlKssPcyRKKRUZIj/xx3YBwKuJXymlQhLxid8TbxO/WxO/UkqFJOITv8RU1vj3hTkSpZSKDKE8erGviMwXkdUiskpEbq6jzGQRKRCRpc7rzrqW1Sqcpp7vVm9ss1UqpVQkC+XRi37gN8aYxSKSACwSkU+NMatrlfvaGHN2y4fYCKfGn4Je
uauUUqFotMZvjNlpjFnsvD8ArAH6tHZgIfPGUGa8JMuBcEeilFIRoUlt/CKSgX3+7oI6Zk8UkWUi8qGIDGuB2EINilJvEr28+jAWpZQKRShNPQCISDzwJnCLMWZ/rdmLgX7GmEIRORN4BziijmXMAGYApKenNzvo2ko8SST4tcavlFKhCKnGLyJebNJ/1RjzVu35xpj9xphC5/0HgFdEutVR7hljzFhjzNjU1NTDDL1aqSeRRDTxK6VUKEIZ1SPAs8AaY8z/1lOmp1MOERnnLDevJQNtSKk3iUSjnbtKKRWKUJp6JgFXAitEZKkz7Q4gHcAY8xRwIXCdiPiBEuASY0yb3TWtzJtEb63xK6VUSBpN/MaYbwBppMwTwBMtFVRTlXmTSKIQEwwiroi/Jk0ppVpVh8iSZd5koiTAnry94Q5FKaXavQ6R+Ffvs5vxyHvfhTkSpZRq/zpE4s8J2scuRlcUhDkSpZRq/zpE4t8bsIm/i+jIHqWUakyHSPzjhtlrxfrFloU5EqWUav86ROI/79hMAA7s2xPmSJRSqv3rEIk/Kt4+fnHP7p1hjkQppdq/DpH4Xd5oCo2PZCkKdyhKKdXudYjED5BPPMlygCtm1nXjUKWUUpU6TuI38aRQyDcbcmnDu0UopVTE6TCJf5+JJ9kZzvnljzlhjkYppdqvDpP4C4gn2Xn8onOjUKWUUnXoMIm/S7eeVTX++Gh3mKNRSqn2q8MkfldcV5Ipwouf1Tv1Fs1KKVWfDpP498b2xyWGQZLNn95ZGe5wlFKq3eowiX937JEADHNtCW8gSinVzoXy6MW+IjJfRFaLyCoRubmOMiIij4vIBhFZLiKjWyfc+uVEp1FsohkmW9p61UopFVFCefSiH/iNMWaxiCQAi0TkU2PM6hplzgCOcF7jgb87/28zZQFhjUlnqGtrW65WKaUiTqM1fmPMTmPMYuf9AWAN0KdWsWnAS8b6DkgWkV4tHm0DissDrAxmMEy2IATbctVKKRVRmtTGLyIZwNFA7fsi9AG21/icxaEHh1ZVXO5nq+lJvJSShN6zRyml6hNy4heReOBN4BZjzP7mrExEZojIQhFZmJPTslfXFpcHyDVJAHQTfRKXUkrVJ6TELyJebNJ/1RjzVh1FsoG+NT6nOdMOYox5xhgz1hgzNjU1tTnx1uuUoT3IwSb+VE38SilVr1BG9QjwLLDGGPO/9RR7D7jKGd0zASgwxrTpzfEvGpPGA1dOBaAbBZz0P1+05eqVUipihDKqZxJwJbBCRJY60+4A0gGMMU8BHwBnAhuAYmB6y4faMBEhKrknYJt6NuYUEQwaXC69b49SStXUaOI3xnwDNJg9jb0P8q9aKqjmMr4U/MZV1cb/2vfbOFDq57rJA8McmVJKtR+h1PgjhhEXeSTSFdv3/Efn1g2a+JVSqlqHuWUDgFuEXJOko3qUUqoBHSrx90zyQXx3TfxKKdWADpX4AXr06sso1yZ+7XkD0EcwKqVUbR0u8YvXB8BNnndIQe/Lr5RStXW4xF+WfkLV+0QpDmMkSinVPnW4xO8/aho/K/8NQNU9e/wBvWmbUkpV6nCJP9rrosDEAZAkNvGXa+JXSqkqHS/xe1wUEA9U1/jL/Zr4lVKqUodL/FGe6hp/ZRt/mSZ+pZSq0uESf4zXzcD03kB1jb+sQhO/UkpV6nCJX0S4+7/GUmY8DIivAKA8EAhzVEop1X50uMQPcGTPRLxxKUzoY29FVKo1fqWUqtIhEz+AKzaF6Ap7szZt41dKqWodNvHjSyLKSfw6qkcppap14MSfjNdJ/Kt3NusRwUop1SGF8ujF50Rkj4isrGf+ZBEpEJGlzuvOlg+zGWKScZfbhH/v+6vDHIxSSrUfodT4XwBOb6TM18aYUc7rnsMPqwX4kquaegA+WbUrjMEopVT70WjiN8Z8Bextg1hali8JV1kBA7rGADDj5UU8+83mMAellFLh11Jt/BNFZJmIfCgiw1pomYcnvjuYIFf5
51B5X35t8lFKqZZJ/IuBfsaYkcD/Ae/UV1BEZojIQhFZmJOT0wKrbsCoy2DIOVxd9iqZojV9pZSqdNiJ3xiz3xhT6Lz/APCKSLd6yj5jjBlrjBmbmpp6uKtuWHQCnPM4ftyc7f6uddellFIR5LATv4j0FBFx3o9zlpl3uMttEbFdWB492kn8trnHGH0co1KqcwtlOOcs4FvgKBHJEpGficgvReSXTpELgZUisgx4HLjEtKPsOt91LGmSyxDZBkBFoN2EppRSYeFprIAx5tJG5j8BPNFiEbWwN/cN5Dc+mOhazZpAP/zBIMZvMAZ8Xne4w1NKqTbXca/cdeygG1uD3ZnosiN6KgKGSQ98zuA/fRTmyJRSKjw6fOIH+DY4lPGuNbgIsm7XAXILy8MdklJKhU2HT/zv33gcXwZHkijFTHUt4eKnvw13SEopFVYdPvEP6ZXIJ8GxbAumcr3nXSpH9yilVGfV4RO/2yUEcPNs4ExGuzYwUHaEOySllAqrDp/4K30bHArACNkU5kiUUiq8Ok3i32h6U2KiGO6qvn3DH99ZEcaIlFIqPDpF4l9592k889PxrDIZZNZI/K98tw2K98L7v4ayA2GMUCml2k6nSPzx0R66xkezItifYbKFaMo50/Udw2UTrJkLC58ld+mH4Q5TKaXaRKdI/ACJPg/LggOJkzIe9D7D36IeZ270Hylb8CwAc+a+x/ebI++xAwcp2Qfbvw93FEqpdq7TJP7YKA+fBUdTbKI5z/0fNgZ7ETBC9J5lABznWoHvPw9DYR23i/aXwTePQlFu6wZZkA37dzb/++/dCM+fYZuvlFKqHp0m8XeLjyJzQBqru5wEwAuB01hqBgEQNEKmawsj1v8NXrvIJs787RAM2i9/+Dv47L9h2azWC9BfDi+cCW9cVX+Z0v2HHhiMgfJiyF5sm62Cftj4eevF2Zi9m+2BsrX85wl451d2u5VSzdJpEr/H7WL2jImUHPMrPgmM4d3AJL4IjATg30H70LDC2DTYtRL+eiQ8mmlf834Di16wC1n/KTx/Fix9LbSVZi2CTV/CC2fDf/6vevp3f4dVbx9cdslLsG8LZP1QXWPfuwl2Lq8uM/dmePaUg5PeezfA46Pg4zsgJsW+Nnx2aByLXzo0WQacg0RJfmjbU58dS+zBJ3cDPHEMfPnQ4S2vPtmL4dM/wdJXYMkrrbMOpTqBRu/O2dGUJR/BjIrfAPBW4HhGujZyt/8qrjEf0eW4PzFtAPD9P6DrQFj+T/hhJgyYArFdYOWbdiFbv7EPehlyjj0r2PoNZC+C9GMh7Rj47E5ISoePb7c1cICt/4HuQyF3PXx0m522ayUceTokpcGXD0NCbziwA9a+b6fN+RmU5sNJd8K4X8CPH0FFMexcBr1H2Rp+ZQIs3A2n3Au7lsOPH9szg/judl3fP2PLdDsKfEnw9V9B3BAdb7fPGwfXfAS9RoT+Q5YXQ1SsXddrF8OAyeCOhmAFLH8DpvwBXC1Yr8heBK9dAvE9IKmvPQMbdp7dD0qpJpFw3Tp/7NixZuHChW2+3i/W7eHq53+oc97/XDSSC8akVU8oL7LNO5kXwIo58MFvoesg8Pjs8M+h58LaebZmXinjeNjytX0fFQ9nPAipQ2wTzv4sO73/iTYpr/in/ezy2oR59Tx48RwwThNTXHfoO84eCFKHQM4aO33IOdD7aPjqf6DbETD4LHsQ+NknsGc1vHiuTYjx3e1BYtwvYMnL0O9Y+9lfDoEy8JfCkWdA9kJI6AVDzoWy/TDqcug+2K6rsn9j0fNwwUzIOA4WvWjPhI48DbZ9B+KCoj22fNox9qzlwudh6Hktk/yzFsHL59mzmcv/CWWFMHOqPSAe/5vDX35DCvfYpr6JN0DaWKgoga8ettvWlAOlUi1ERBYZY8Ye1jI6W+L/94ZcLp+5oM55R3SPZ/2eQjb8+Qw87loJa+dyePp4W5PtMwZe+S9weaD/CTDiElvjnX+fbVLpPhR8yTDi
Yhg73X6/JN/W2KMTYdDJ4ImC3ashbz3M/wukDoaLX7Q1/13LYORlNunHdrU19o//YGvY8T0hd51dZrcj4adzIaGnbcaxD0KzTS+f/Anyt8LUO2HERfDWL2D5bPv9n861iXrBU3DWI7DpC3jrWvtdl8ceeIacYxNs/jYbo8dnk/qgk21tu/tQe5BJHWLjPrDLJuYu/eGxUVCca3+b8/5uz16aa8cSeGmaXfbVH0BSHzv91Ysh63u4dRVExTV/+Y355I+2mc7ltTGkj7cHWXc0nPEAjJle/bsr1QbaJPGLyHPA2cAeY0xmHfMFeAw4EygGrjbGLG5sxeFK/As25fGTZxp+Bu+Ku04lwec9dMaa92HgVJuA139ma8U1k1owAN8+AYNOgR5DQw/KGPtqqHa8a4WtbZYXwpZ/w/hf2oNCqDXqnHXw78dhyh3VybOmA7vswaq8yDYFLXkVUtLBEwPH3WLPaj75oy077Hz4r5n2DCaxD7hr/VZFubZZ7JM/2aauHkPhhN/ZM6TsxfbsYsHfbf/C0VfC4LPBEw1x3WDjfHC57UFj1dvw9i8hLhWmfwDJ6dXr2PotPH86nPlXGPfz0H6DUBhjz1jSjrHNbI9kQr9J9oC28XPI/RH6jrdncxv/BSf8P5j6x5Zbv1KNaKvEfwJQCLxUT+I/E7gRm/jHA48ZY8Y3tuJwJf6dBSVMvL/hUS8//OFkUhOi2yiiCFF2AD68DdInwMhLwR1C99C+rbaJac37sHcjjLwEFr9sa+jlhba9vrzInk0ATHsS3vyZPVhM/xBeuQBSMuDS2ZDQ4+BlGwMzT4KiHLh+gT0Yt4TV78EbV8IFz8LulfDNI/DLf0PPTNt5/e71cPr90OtomHuT3b5JN8PUPx16AFSqFbRE4m+0umiM+QpoaGD4NOxBwRhjvgOSRaTX4QTVmnolxZDgs0lr6uDudZYprQi0ZUiRIToBznsSRl8ZWtIHSOlna8PTP7Ady4tfsv0R3hjbqX35HHsAKM23iXvWT+yBwBtjr0co2Qun/eXQpA+2eeXku21T1Mvnw5s/h3U1nqq2Z61tNqsobTzOgiz49knbtLXuAzvt83vhu6dg+EU26QN0G2T7UfqMsWdaZz8Co38K/34MPrvLnomgKY6ZAAAXuElEQVSVF4X22ygVRi0xqqcPsL3G5yxn2iFXIonIDGAGQHp6eu3Zbea6yQN56KN1DOoez+SjUrnz3VUHzb/j7RW8dM04RNtuW0ZsF7jum+p+iIoScEfZJp0Tb4PEXjD4HHjnOtt5PvIntjM8Od2eYdSn//Fw7I2w6CXbFLXin3DJa5B6FLx/C2z71vZfXPm2Xa/ba4fkfv2/cMo90PcY2/T08nlQWmBHSB3YBYlpdmht1yNsB3J93F4493G7Hd8+YV8Zx8MVb9l5C5+FtHHaCazanZA6d0UkA3i/nqae94EHjDHfOJ//BfzeGNNgO064mnoAZn69ifvmrWH6pAz++5xhZNw275AymX0Sefv6SXhrd/KqthEM2JcnKrTyZQfgbxOhoEYd5MjTbYd610H2wrJ+x9pO+rIC24l93K9h1Vv2rGDqH+Cj2+3Zx3/NhD6jIaV/aH0oZQfsVdO+ZDv6afDZ0HMEfPEX8MbCpbNs539DSvfD/h12iG1in7o7jAN+MAHbH6I6rZZo6mmJGn820LfG5zRnWrtVmcz9AXvQS/B5OFDqP6jMyuz9ZO0roX+3VhwxourncttXqKIT4KIXbVNNUY4dtnrRCzD3FjuaafhFtoM7rhtMn2eHZH71kD0AXPGmTcx9x8Py12HI2ba5qUnrfsG+Tx0MH/3eDsE94lR7Bfica+DsR+0w3z1r7DUXR55mR1zlrrNnOrtq3CK833F2eXtWQdZCO3R3wGR7ZlK42w6pzfrBjijyxtjrQZLS7BnO0Gl2VFZsV/u93PX2c/42e1DZ/DX0HG7Pigqy7PDdLgNtp3WvEbb82nmw4V92WVGxdpsSetqDU/ZC2L0K
kvvZUWNJabZ5q+cICJTb36K0wJ7x+JLtSKjEXvaixLhutl9n3xbbdxPf024Pxo6Syl5kv2sCdru6DLQHwH1bqq8b8cbY0XA5a23fUGJvu9zyYrv+7kPtgX33anvAT+xlf5/SAvt3EZNihzEn97OVhOhEuw/Ki+01Lr4kOzy5NhO0I9V2r4Yew+zvEZ1gD+xlB2x/UFQ8eH32+x6ffXlj7HeDfjsyrGQvHHWGHSIeRi1R4z8LuIHqzt3HjTHjGltmOGv8b/ywnd+9uZyrJvbjnmmZbMsr5oSH5x9S7rNfn8ig7vFhiFC1mECFre2nHnnwdGPsP3yPz17v0JJ2LLGJpt8ku+5/TIXyA3aEVK8RNhls+sImiMokPfFX9qK//dl2eG+gnIMeE9pnjE2MlVL6w77NB6/XHW2vz6hUufyaXJ7qiwrr40uyTVbrP7VNcuU1blkenQjdh9jtSsmwZylRsfbAIW6btN1Rdh21191SxGWHMvtL7foD5YDYikLltkUn2oOa/YIdUBCXas/oxAXFec7n/XZ7Kg9YpQUN/C6Jtq8qZ529IWLZAdtH5Y21I9f8pfbvLRiw7/2l9mxSxP7ugXKISYbx18HE65u/+W1R4xeRWcBkoJuIZAH/DXgBjDFPAR9gk/4G7HDO6YcTUFuYdnRv1uzazy0n2WTgi6r7dF47eTsAt/fQpA/2H2NyK/Uz9T66+n3qkXDTEtsHkXqU/YcP8OMn9jqE+B62Zl3z4HPEqbD6XTst8wJ7drL+E3shXvfB9ixi6p9sn0L+Nhhwoq2pj5thDwaeGNi+wF5nkT7BJqLE3vZAd9RZtszezba27o6yteduR9izke5DbZwud/UtPvZn20TpjbPDWus6EystsDVef6ldv4hNjCV77ZlFTBdb4y7Os7HsWWMTYVyqXV5Fia2hJ6fbJFleCHkb7LJT+tvEXF5ky8V1sy+wV85XFNntQOw1J8bYWvm+LbbPJu2YgwckBIN2+b7Eg69/6UQ63QVcddlbVM7oez89ZPobv5hIIGiYOLBrGKJSSqlDtclwzs7A7ar7iH/x099y6T++47tNeW0ckVJKtR5N/DR+prd2535mf7+NjNvmsUAPAkqpCKeJH4jxunEJ3HveIX3XANw1dzW3vWVHXbz2/ba2DE0ppVqcJn7s8M5N95/FlRP6NVq2oKSiDSJSSqnWo4m/lhunDuKNX0zk6mMz6pyfX6yJXykV2Trdg1ga85tTjwIgaAwv/GfLIfP3l1QQCJp6O4SVUqq90xp/PbzuuhP7ptwiBt7xAbsK7M2/duSXMPnh+WzfW9yW4SmlVLNpjb8ejd2gbcL9/6JPcgwnD+nOlrxi/rkoi8zeiZx4VCrRnibcakAppdqY1vjrkejcunlgahx/vWhknWWy80t48dutAKzMLmDGy4t4+KN1bRajUko1h9b46zGoewIzrxrLxIFdiYv2MDItiae/2sScRVl1lt+RXwLAlryDm3wybpvHZePTufvcYXqnT6VUu6CJvwEnD61+AMgRPRJYmV3/DZzW7rI3svpszW78gSA3z15KXpG9YdZrC7axM7+E56dX37vu1QVbmTigKwNSD/8mcAUlFbiEuh8XqZRStWgVtAlioqrb7sf171Jvuef+vZl5K3by3abqB5fNX5dDxm3zuO6VRewqKOUPb69k6v98Wef33/hhOxtzCuucl7WvmIzb5rFoa/WyR979CeP+/K+mbo5SqpPSGn8T/PWikfzh7RU8dsnR9Ej0sX73AU555KtDyr23bEe9y/hw5S76dql+PuywOz/iuztOqqqtVwSC/O7N5YC9piApxsuA1DimDu7B6z9sq2pKmv39dsb0qz74lDh3EjXGhPXJYXOX7WBkWjLpXVvoGbhKqRanNf4mGJgaz+wZE+mRaB8OfkSPBJ66YjRH9ji4uWZl9v66vl7lma82Vb0vKg9w7YsLydpXTDBoePrLjVXz/u/zDdw3bw3XvLCQd5dm8/s3V/D3L+z8
/aUVVASCbMmtfsZruT/I8Ls+4b73V5Nx2zzmOgegVTsKqs4Q/rMhl2e/qXUfd4cxhoLDuECtpDzAjbOWcNHT/2n2MpRSrU9vy9xCcgvL2JRTxOzvt/HWEvsAsnk3HcdZj3/TZjH836VHc+OsJQdNu++8TP74zkoA1t57OoP/ZB9IfurQHjx95Ri+27SX8f27sHJHAec+8W8AnrjsaBZs2svtZw4mNurgk8Jyf5Aoj4sftuzl4Y/X8crPxhPlcVERCHLEHz6sKrflgbOq3u8rKiclLsRHKDZi/to9JMV6ESBoYEy/lBZZrlKRoiVuy6yJv4UFg4bHP1/PjvwSHrpwJB+t3MUvX6l+ctLUwd35fO2esMT22s/Hc9k/FlR9jva4KPPX/5SkAalxPHnZaI7oHs/ibfkUlFTw85cWMveG47hp9hI25xbRIzGawT0TyS+pYNn2/KrvVib+ect38qvXFvPeDZMYkZZ82NtQ+/nIG/58Bh4dLaU6kTZL/CJyOvAY4AZmGmMeqDX/auBhqp+1+4QxZmZDy+yoib82YwxvLNyOx+UiO7+EKUd156Kn/0NpRXXCvWHKIJ6Yv6Hq88+P788/vrbNMTFed1X7/fd3nERBSUWd/Qqt6a5zhnLX3NVEuV2UB4JkdI09ZNhqbVEeF2dm9uSdpba5qU9yDOWBIK9eO570LrEEgoa46EO7mFbtKGBbXjGnZ/bkd3OWsyK7gFk/n8AZj33Nrv2lh5R/6MIRTBrUjd5Jvjbr2ygpD/DWkiwuPSYdl966Q7WxNkn8IuIGfgROAbKAH4BLjTGra5S5GhhrjLkh1BV3lsRfn2e+2si+4gqunzyQBJ+XA6UVbMwp4tuNeVw3eSDn/+3fLNmWz+b7z2RfcQV7i8oY1D0BgLveW1XnfYQixZE94skvriC/uIL/d9pR/PyEARSW+fF5XAxymovqaraqy+j0ZBZvy+fe8zJDurtqQyoCQXbvLyUt5eCO6coO809W7WLGy4vo1zWWrZWd7DMmMGGAPqFNtZ22SvwTgbuMMac5n28HMMbcX6PM1Wjib1FFZX4qAkGSYw9tGy8u97NsewG5hWX0TPJxTEYXjDHMXb6TRJ+Hq5//gYGpcWzMsR2/xw3qRlpKDP/ZmMe2vcVEeVyM79+Fhy4cQc9EH0EDgaDhqS838uhnPzK+f1e+bcYDZwZ1j+fMzJ48/vmGxgvX8ORlo/nVa4tDKtsnOYZs52K52iYM6MJfzh8e8rURlX/7lWcKd7y9gtcWbGPFXacS43WzdHs+v5uznOz8Eob3SWJFdkGdTWOv/Xw8c5ft4N5pmXjcLowxPPzxOs4c3ovMPkkhxaJUqNoq8V8InG6Mudb5fCUwvmaSdxL//UAO9uzgVmPM9jqWNQOYAZCenj5m69athxO7qse+onJ8XjeFZX68bqnz4NGYrH3FHPfg/HrnJ/o89Osax0lDuvPoZ+uB6nb9G15bzKRB3Vi6LZ/XF27nZ8f1r3ckUSgSfR7mXHcsOwtKOfHIVPYWlfPIpz/y8neN//1Ee1y8e8MkjuyegMslVASCeN0uNuwp5OT/tddRvPurSXyzIZfnvtlMXlE5v5oykH98tZnyQP39H3V57drx/L85y7l0XF/++smPxEd7WHn3ac3a5lCV+QN6b6hOpj0l/q5AoTGmTER+AfzEGDO1oeVqjT8ybMwpZEC3OJZuz8clwoi0JF7/YTvH9O/CQKdmvWHPAdbtKuSsEb0O+q4xhqCpfqbxj7sP8NqCbfz61CPZkV/Cxj1FLNy6l8Vb99G3Syy9k2NI7xLLk/M3sLOglDevO5bvNuVx7fH960xuWfuKefnbrfTtEktFIMjdc1cfUqa1vHjNOH763PeNlvvd6UdxyTHplPuD5BaWEQgaMvsk4XYJ5f4gZf5A1TUcWfuK6Znoa7CzuuZ1GvPX7mH6Cz/w4c3HM6h7PP6AYff+UjK6xbXMRtbw6erd/GvNbh64YESLL1s1Tbtp6qlV3g3s
NcY0eI6riV/VZ29RObsKShnaO7FJ36sIBHl7cTYfrtzJNxtyqQgYxvZLYeHWfY1+1+0SAsHqfwsXj01jylHdyc4vISnGy+mZPXGJsH5PIVvzipg2qg95hWWICB+t3GU7oZv5WM4ucVHkF5dTufrKDv2RfZPJ7J3Iqwu2cerQHnyyejfpXWLxeV2UVgTZVs+twI/JSCHa42ZfcTmnDevJJ6t34XW76JMcw7kje1MeCDKgWzw78kv4YMVOfnJMX0r9Qd5YuJ3x/btwoNTPsu35/PGsoRgMhWX+qmHJt58xmHNH9cbrdlFY6icxxotLwOd1s31vMYO6x7N7fxlb8ooIBg1ej4vM3kn4vNUHs/0lforK/fROjjko7nW7DpDeJZaYKDfGGDbsKWTJ9nzG9EuhS2wUybFeRISS8gAiUFjmp2tcFCJCcbmfXQWlh3ULlE05hfTvFld1YC0q8+MPGHxRLqI9borL/RhD1aCEvUXl+INBuif46l1mdn4JvRJ9Bw0C+GZ9LiP7JjX7Fittlfg92Oabk7Cjdn4ALjPGrKpRppcxZqfz/nzg98aYCQ0tVxO/akuVNWVjDGX+IGt3HWBzbiEl5UEuG58O2IvivtuYx5TB3Zt9Q73SigDz1+7hhy376J3sY2NOITkHyvhsTfUQ3rSUGLL21d1P4XULFYHwDLFuTV63UJlq/M4RLiHaA04+tGc/tmktLspNcUWA2qkpIdpDtNdFXlF51by4KDdF5YGD1lN57UkgaKgIBHGJkODzUH1cN+wtKqdLXDSVu7mw1F+1nC5xUQSNcfrZ7JcqD7YAKbFeXCLkFZUD0D0hGhEQhKAxlJQHiHbKF5b5cbuEnok+/MEgpRVBCkoquGx8On85f3izfsu2HM55JvAodjjnc8aYP4vIPcBCY8x7InI/cC7gB/YC1xlj1ja0TE38qrMLBA0uAWPA5ZKDOpv9gSD7iiuIjXLjdtmkeaC0goqgcWq54HW5KCipoLDMjz9oyNpXzJE9EtiRX8KmnCIGpMYRE+Vm7c4D+IMGj0vokejjx90H6BYfTWKMh/0lfvJLyukWH020x/Z95BaW0y0+igSfp+pgmRIbxZqdB/C4hUSfB3/QNuOVlPurYs4vLifGa5vkFm/Lx+MWPC4XA7vH4XJq0YKtKW93Yq1p8bZ8xmWkEAiC2wVLt+fTKymGPikxGAOFZRUYA7v3l1FYVsG2vGLG9e9CYoyXzblF5BWWc/wR3Sh1hj973C48LqE8EKSkPFBV6zYGXGJ/f2OwSVvgmw25nDS4B2X+IMGgYf0ee+PFPimxdE+IZtveYvIKyziqZyLFznYXlFTQM9FXdSASwbmg0RDtcbG/pMLZt3ab/EFDQrSHW085sll9b85vrRdwKaVUZ9ISiV8veVRKqU5GE79SSnUymviVUqqT0cSvlFKdjCZ+pZTqZDTxK6VUJ6OJXymlOhlN/Eop1cmE7QIuEckBmnt7zm5AbguGE266Pe2bbk/71tm2p58xJvVwVhC2xH84RGTh4V651p7o9rRvuj3tm25P02lTj1JKdTKa+JVSqpOJ1MT/TLgDaGG6Pe2bbk/7ptvTRBHZxq+UUqr5IrXGr5RSqpkiLvGLyOkisk5ENojIbeGOp5KI9BWR+SKyWkRWicjNzvQuIvKpiKx3/p/iTBcRedzZjuUiMrrGsn7qlF8vIj+tMX2MiKxwvvO4VD4jrnW3yy0iS0TkfedzfxFZ4MTwuohEOdOjnc8bnPkZNZZxuzN9nYicVmN6m+5LEUkWkTkislZE1ojIxEjePyJyq/O3tlJEZomIL5L2j4g8JyJ7RGRljWmtvj/qW0crbc/Dzt/bchF5W0SSa8xr0u/enH1bL2NMxLywTwDbCAwAooBlwNBwx+XE1gsY7bxPwD6ucijwEHCbM/024EHn/ZnAh9iHEk0AFjjTuwCbnP+nOO9TnHnfO2XF+e4ZbbBdvwZeA953Pr8BXOK8fwr7tDWA64GnnPeXAK8774c6
+yka6O/sP3c49iXwInCt8z4KSI7U/QP0ATYDMTX2y9WRtH+AE4DRwMoa01p9f9S3jlbanlMBj/P+wRrb0+Tfvan7tsFYW/MfWiv8sU8EPq7x+Xbg9nDHVU+s7wKnAOuAXs60XsA65/3TwKU1yq9z5l8KPF1j+tPOtF7A2hrTDyrXStuQBvwLmAq87/wDyq3xh1y1P4CPgYnOe49TTmrvo8pybb0vgSRsopRa0yNy/2AT/3ZswvM4++e0SNs/QAYHJ8pW3x/1raM1tqfWvPOBV+v6PRv73Zvzb6+hOCOtqafyj71SljOtXXFOtY4GFgA9jPMgemAX0MN5X9+2NDQ9q47prelR4HdA0PncFcg3xvjriKEqbmd+gVO+qdvZWvoDOcDzYpuuZopIHBG6f4wx2cBfgW3ATuzvvYjI3T+V2mJ/1LeO1nYN9swDmr49zfm3V69IS/ztnojEA28Ctxhj9tecZ+whOSKGUYnI2cAeY8yicMfSQjzY0/C/G2OOBoqwp/lVImz/pADTsAe03kAccHpYg2phbbE/2mqfi8gfAD/wamuvKxSRlvizgb41Pqc509oFEfFik/6rxpi3nMm7RaSXM78XsMeZXt+2NDQ9rY7prWUScK6IbAFmY5t7HgOSRcRTRwxVcTvzk4A8mr6drSULyDLGLHA+z8EeCCJ1/5wMbDbG5BhjKoC3sPssUvdPpbbYH/Wto1WIyNXA2cDlzoEGmr49eTR939avtdogW+OFrbVtwtZyKjs+hoU7Lic2AV4CHq01/WEO7kh6yHl/Fgd3Vn3vTO+CbYtOcV6bgS7OvNqdVWe20bZNprpz958c3MF0vfP+VxzcwfSG834YB3dibcJ2YLX5vgS+Bo5y3t/l7JuI3D/AeGAVEOus70XgxkjbPxzaxt/q+6O+dbTS9pwOrAZSa5Vr8u/e1H3bYJyt+Q+tlf7gz8SOmNkI/CHc8dSI6zjsKeNyYKnzOhPb1vYvYD3wWY0/SgGedLZjBTC2xrKuATY4r+k1po8FVjrfeYJGOnBacNsmU534Bzj/oDY4f4jRznSf83mDM39Aje//wYl5HTVGurT1vgRGAQudffSOkygidv8AdwNrnXW+7CSRiNk/wCxs/0QF9ozsZ22xP+pbRyttzwZs+3tlTniqub97c/ZtfS+9clcppTqZSGvjV0opdZg08SulVCejiV8ppToZTfxKKdXJaOJXSqlORhO/Ukp1Mpr4lVKqk9HEr5RSncz/B7u+sFf0oLPZAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "complete - train time: 25275s, best epoch: 187, best loss: 0.926487, best accuracy: 79.50%\r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 432x288 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle\n",
    "import paddle.fluid as fluid\n",
    "from paddle.utils.plot import Ploter\n",
    "import numpy as np\n",
    "import time\n",
    "import math\n",
    "import os\n",
    "\n",
    "epoch_num = 300   # 训练周期，取值一般为[1,300]\n",
    "train_batch = 128 # 训练批次，取值一般为[1,256]\n",
    "valid_batch = 128 # 验证批次，取值一般为[1,256]\n",
    "displays = 100    # 显示迭代\n",
    "\n",
    "start_lr = 0.00001                         # 开始学习率，取值一般为[1e-8,5e-1]\n",
    "based_lr = 0.1                             # 基础学习率，取值一般为[1e-8,5e-1]\n",
    "epoch_iters = math.ceil(50000/train_batch) # 每轮迭代数\n",
    "warmup_iter = 10 * epoch_iters             # 预热迭代数，取值一般为[1,10]\n",
    "\n",
    "momentum = 0.9     # 优化器动量\n",
    "l2_decay = 0.00005 # 正则化系数，取值一般为[1e-5,5e-4]\n",
    "epsilon = 0.05     # 标签平滑率，取值一般为[1e-2,1e-1]\n",
    "\n",
    "checkpoint = False                   # 断点标识\n",
    "model_path = './work/out/hs-resnet'  # 模型路径\n",
    "result_txt = './work/out/result.txt' # 结果文件\n",
    "class_num  = 100                     # 类别数量\n",
    "\n",
    "with fluid.dygraph.guard():\n",
    "    # 准备数据\n",
    "    train_reader = paddle.batch(\n",
    "        reader=paddle.reader.shuffle(reader=paddle.dataset.cifar.train100(), buf_size=50000),\n",
    "        batch_size=train_batch)\n",
    "    \n",
    "    valid_reader = paddle.batch(\n",
    "        reader=paddle.dataset.cifar.test100(),\n",
    "        batch_size=valid_batch)\n",
    "    \n",
    "    # 声明模型\n",
    "    model = SSRNet()\n",
    "    \n",
    "    # 优化算法\n",
    "    consine_lr = fluid.layers.cosine_decay(based_lr, epoch_iters, epoch_num) # 余弦衰减策略\n",
    "    decayed_lr = fluid.layers.linear_lr_warmup(consine_lr, warmup_iter, start_lr, based_lr) # 线性预热策略\n",
    "    \n",
    "    optimizer = fluid.optimizer.Momentum(\n",
    "        learning_rate=decayed_lr,                           # 衰减学习策略\n",
    "        momentum=momentum,                                  # 优化动量系数\n",
    "        regularization=fluid.regularizer.L2Decay(l2_decay), # 正则衰减系数\n",
    "        parameter_list=model.parameters())\n",
    "    \n",
    "    # 加载断点\n",
    "    if checkpoint: # 是否加载断点文件\n",
    "        model_dict, optimizer_dict = fluid.load_dygraph(model_path) # 加载断点参数\n",
    "        model.set_dict(model_dict)                                  # 设置权重参数\n",
    "        optimizer.set_dict(optimizer_dict)                          # 设置优化参数\n",
    "    else:          # 否则删除结果文件\n",
    "        if os.path.exists(result_txt): # 如果存在结果文件\n",
    "            os.remove(result_txt)      # 那么删除结果文件\n",
    "    \n",
    "    # 初始训练\n",
    "    avg_train_loss = 0 # 平均训练损失\n",
    "    avg_valid_loss = 0 # 平均验证损失\n",
    "    avg_valid_accu = 0 # 平均验证精度\n",
    "    \n",
    "    iterator = 1                                # 迭代次数\n",
    "    train_prompt = \"Train loss\"                 # 训练标签\n",
    "    valid_prompt = \"Valid loss\"                 # 验证标签\n",
    "    ploter = Ploter(train_prompt, valid_prompt) # 训练图像\n",
    "    \n",
    "    best_epoch = 0           # 最好周期\n",
    "    best_accu = 0            # 最好精度\n",
    "    best_loss = 100.0        # 最好损失\n",
    "    train_time = time.time() # 训练时间\n",
    "    \n",
    "    # 开始训练\n",
    "    for epoch_id in range(epoch_num):\n",
    "        # 训练模型\n",
    "        model.train() # 设置训练\n",
    "        for batch_id, train_data in enumerate(train_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in train_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = train_augment(image_data)                                                        # 使用数据增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "\n",
    "            label_data = np.array([x[1] for x in train_data]).astype(np.int64)                        # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                             # 转换数据类型\n",
    "            label = fluid.layers.label_smooth(label=fluid.one_hot(label, class_num), epsilon=epsilon) # 使用标签平滑\n",
    "            label.stop_gradient = True                                                                # 停止梯度传播\n",
    "\n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label, soft_label=True)\n",
    "            train_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            # 反向传播\n",
    "            train_loss.backward()\n",
    "            optimizer.minimize(train_loss)\n",
    "            model.clear_gradients()\n",
    "            \n",
    "            # 显示结果\n",
    "            if iterator % displays == 0:\n",
    "                # 显示图像\n",
    "                avg_train_loss = train_loss.numpy()[0]                # 设置训练损失\n",
    "                ploter.append(train_prompt, iterator, avg_train_loss) # 添加训练图像\n",
    "                ploter.plot()                                         # 显示训练图像\n",
    "                \n",
    "                # 打印结果\n",
    "                print(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\".format(\n",
    "                    iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "                \n",
    "                # 写入文件\n",
    "                with open(result_txt, 'a') as file:\n",
    "                    file.write(\"iteration: {:6d}, epoch: {:3d}, train loss: {:.6f}, valid loss: {:.6f}, valid accuracy: {:.2%}\\n\".format(\n",
    "                        iterator, epoch_id+1, avg_train_loss, avg_valid_loss, avg_valid_accu))\n",
    "            \n",
    "            # 增加迭代\n",
    "            iterator += 1\n",
    "            \n",
    "        # 验证模型\n",
    "        valid_loss_list = [] # 验证损失列表\n",
    "        valid_accu_list = [] # 验证精度列表\n",
    "        \n",
    "        model.eval()   # 设置验证\n",
    "        for batch_id, valid_data in enumerate(valid_reader()):\n",
    "            # 读取数据\n",
    "            image_data = np.array([x[0] for x in valid_data]).reshape((-1, 3, 32, 32)).astype(np.float32) # 读取图像数据\n",
    "            image_data = valid_augment(image_data)                                                        # 使用图像增强\n",
    "            image = fluid.dygraph.to_variable(image_data)                                                 # 转换数据类型\n",
    "            \n",
    "            label_data = np.array([x[1] for x in valid_data]).reshape((-1, 1)).astype(np.int64) # 读取标签数据\n",
    "            label = fluid.dygraph.to_variable(label_data)                                       # 转换数据类型\n",
    "            label.stop_gradient = True                                                          # 停止梯度传播\n",
    "            \n",
    "            # 前向传播\n",
    "            infer = model(image)\n",
    "            \n",
    "            # 计算精度\n",
    "            valid_accu = fluid.layers.accuracy(infer,label)\n",
    "            \n",
    "            valid_accu_list.append(valid_accu.numpy())\n",
    "            \n",
    "            # 计算损失\n",
    "            loss = fluid.layers.cross_entropy(infer, label)\n",
    "            valid_loss = fluid.layers.mean(loss)\n",
    "            \n",
    "            valid_loss_list.append(valid_loss.numpy())\n",
    "        \n",
    "        # 设置结果\n",
    "        avg_valid_accu = np.mean(valid_accu_list)             # 设置验证精度\n",
    "        \n",
    "        avg_valid_loss = np.mean(valid_loss_list)             # 设置验证损失\n",
    "        ploter.append(valid_prompt, iterator, avg_valid_loss) # 添加训练图像\n",
    "        \n",
    "        # 保存模型\n",
    "        fluid.save_dygraph(model.state_dict(), model_path)     # 保存权重参数\n",
    "        fluid.save_dygraph(optimizer.state_dict(), model_path) # 保存优化参数\n",
    "        \n",
    "        if avg_valid_loss < best_loss:\n",
    "            fluid.save_dygraph(model.state_dict(), model_path + '-best') # 保存权重\n",
    "            \n",
    "            best_epoch = epoch_id + 1                                    # 更新迭代\n",
    "            best_accu = avg_valid_accu                                   # 更新精度\n",
    "            best_loss = avg_valid_loss                                   # 更新损失\n",
    "    \n",
    "    # 显示结果\n",
    "    train_time = time.time() - train_time # 设置训练时间\n",
    "    print('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}'.format(\n",
    "        train_time, best_epoch, best_loss, best_accu))\n",
    "    \n",
    "    # 写入文件\n",
    "    with open(result_txt, 'a') as file:\n",
    "        file.write('complete - train time: {:.0f}s, best epoch: {:3d}, best loss: {:.6f}, best accuracy: {:.2%}\\n'.format(\n",
    "            train_time, best_epoch, best_loss, best_accu))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "infer time: 0.020907s, infer value: cattle\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMgAAADFCAYAAAARxr1AAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAGylJREFUeJztnWlsXOd1ht9zZ+Um7pQoURIteZUVW04c17HjVFmcOGkAJ0VhJGgDA3WWAgnaoPljuECbAv2RAk2CoghSJKhrB0jjpHFSu47T2HGdOnYTWbIta7MsibQW7hSX4XAWznK//phhyuF7eDniSBQpnwcQxDm8c+937/DMved857yfOOdgGIaOd7kHYBhrGXMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAIwBzGMAGpyEBG5R0TeFJFTIvLgxRqUYawVZKUz6SISAnACwN0ABgDsB/Bp59yxpd7T0dHhent7V3Q8Yzn4c8zPzZEtlU6TrbFpg7rHcDhc+7BWgK/YisWCuu3cXJZsoTB/7+dylduNjYwjMZ2U5cZSyxW4DcAp51w/AIjIYwDuBbCkg/T29uLAgQM1HNJYkiI7w8jZPrLte/lVst31oXvUXba1d9Q+rmUoKrZ0ka3J2Un1/f19b5Cttb2BbGfPnqx4/eefe6iq8dXyiLUFwLkFrwfKtgpE5PMickBEDoyPj9dwOMNYfS55kO6c+45z7lbn3K2dnZ2X+nCGcVGp5RFrEMDWBa97yrYLwqqJLxxfeR6X/BTZkmP9ZHv+yZ/wdkl+jgeAP/nsZ9mofF6+r3yGylevAz/y55X3Dg2fJdvk9IA6xuFzR8nWf/I82RIzlddnLptS97eYWu4g+wFcIyJXiUgUwKcAPFnD/gxjzbHiO4hzriAiXwLwCwAhAA8759idDWMdU1Mezzn3NICnL9JYDGPNYTPphhHA5ZkJWgaRZedv3jZoKQxPlNmDYpLfm+G0eoOfI9vE8Ih67NGRUbKFhL9Tm1uayRaJRsjmK0G6czwtGOa3Il/MqGNs39hOttFxDtKH+4Yq95fPq/tbjN1BDCMAcxDDCMAcxDACMAcxjADWZJC+WmhVo87nor/CFAd9mcQsvzfKRXIbtmzWD64Eu6IErJ7Ps+Yzw+fIdvrIb8n21hvHeX9eVNkfz1wDwK+efpxsrZu3ku2OO+/iN4e5QnhiOkG2uVlOEGSzY2RzBU5CAMDYJFcLTE3z5+X8xde7ukSQ3UEMIwBzEMMIwBzEMAIwBzGMAN7WQTp8npE+f4oD27FXXiRbepIDzpEcf99ce9de9dDX3Hwr2bwIfxyHjx4m22vPP0+2pBK4z4zxTHgkHCNbdmKIbADw/M/OkO2G3/8I2d7zvg/yPud4xn5qjPfXv59L+UaHuBOyffs2dYxpn8vW82m+jlGvq+K1VPmnb3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIoKYsloicBpBESd6o4Jzj1MwaxmW5rGTiTc6gYHqGTG0hRcjM48xN/wvPqscOOy51iG/mTM33fvyfZDt64CDZdrRymUubx2NsUDJlxZDSgAGg/wRnt1488WOydffcSLa7bruBbOPH/5dsrz/zU7LNTbMARWpwlzrG+l3vYlsd63k1XdVa8Toaq04+4WKked/vnOPiF8O4ArBHLMMIoFYHcQCeEZFXROTz2gamrGisZ2p1kPc6594J4KMAvigi71u8gSkrGuuZWmV/Bsv/j4nIT1EStH7hgnZyGfUZvCj3RjR2cf/G+MBbZMuOs9JfQ5T7OWay+gke/61SvtK6nWzPPPMSb5fk3ogmr5ttrXGypeY4cD9+VhdtGEmxZMTABAfQ33/kX3m7g11kS59j4fKGIpeKxOq4HGYuxar0ALC9kQNyb+PVZMtK5Wcd0pQhFFZ8BxGRBhFpmv8ZwIcBHFnp/gxjLVLLHWQjgJ+WJXrCAP7NOfdfF2VU
hrFGqEV6tB/AzRdxLIax5rA0r2EEcPn7QTTpwGoD96VWTqjy/U5ZYmzTO/immJ+dJlvf2TfJlp7kNHYuVqce+8QJXhkp1cjqgeE8n+TMBK+2lFBWVYpv58B9ZoqD7ENn9CB9PMdJjKZmVlE8e+p1su2b5CUVrungwDga4fObnmNbU5d+HYeHuA9mQ30bH6dtkQKjVLfsht1BDCMAcxDDCMAcxDACMAcxjAAue5CuxUpKJfgS772A9Q2VJRVEWR8vEuPZ5y233cn7UyZih1/lWe8eRYkQACbOs2DEoX2vka0uzIF7RxMHz3vv4jH+3s1cIv5P3/oW2ZIZLtMH9GuhKRymlVnu2FZelsB3HLiPjnErQbh1I9mkQS9Tev0otyckXmHhje4dOypep2b4uBp2BzGMAMxBDCMAcxDDCMAcxDACWPUgffGi85qH+krwnc1x/3hUmQkH9HX0PG16XQncC8r0fN8kdxRPKQHs3LW7yXbju+5Qx5g/y7PhP/rZL3m7DJeDf/KevWT7w49/mGwnT/HSAGMpTg7kXEgdY8TxttEwb9sU52vR0MJBdSLP59KwkWf7XR0vnTAwri9/UMxwEiOnaAg8/2RloXlymqsjNOwOYhgBmIMYRgDmIIYRgDmIYQSwbJAuIg8D+DiAMefc7rKtDcAPAfQCOA3gPucc11EvwncOc/nKWdu40hc+k+b1/17av49sGxob1ePccuNNZGuqqydbscj92YPjLJb2qxc5eH7rLK/rN6fMSMc296pjLCR5VnnsDC8PMJvka7Gzl2fnw+CAejrBwWrO5yC7UNRWawT8NAfGnuMSglCcP8OJSf5zGB3jZEedsq5jQzMnZBpbeDsAaFKSBnVhTrRs7WipeN13Tl/yYTHV3EEeAXDPItuDAJ5zzl0D4Lnya8O44ljWQZxzLwBYnJO8F8Cj5Z8fBfCJizwuw1gTrDQG2eicGy7/PIKSgIPKQuG48yYcZ6wzag7SnXMOSze/VgjHdZhwnLHOWOlM+qiIdDvnhkWkGwCv/K4gAsiioGpmloPQ/QdfJdvZ4UGyxaIsMAYAnW0sJnZd706yJWYmyHbwIAu6DZ8+RraRsxxwjk3xuRw8zIrmAHBbz/Vk27GJv0Cm2ri/urmDZ5/PDXFf+fAwB6KpJAfPLY16v3dqloP0mSmuANjR1UO2xjj/aaXrFGX5AidKiikeY9HTy9NzrVxWjzAnLJqbK88xHKru3rDSO8iTAO4v/3w/gCdWuB/DWNMs6yAi8gMAvwFwnYgMiMgDAL4G4G4ROQngQ+XXhnHFsewjlnPu00v8itf+NYwrDJtJN4wAVrXc3flAca4ygHpp38u03StHD5Ft5/UcCA6dS6jH+Y+nniPbxz+WJ1vfaRZv6zvHSu5eiMu5J5VZ4cGB02SLF9+tjvEdvb1k+7M//QzZtNnwnS0s3jY0xEmMk4c5uZCc4FR7c7sS6AIoFpQydmXSfUtrE9mcshyd+PzmkMcJ0FBIaUPI8+cHAGlF1C8U5pn9ol+ZDHDQqwcWY3cQwwjAHMQwAjAHMYwAzEEMIwBzEMMIYFWzWEW/iORsZebpv1/gXov2zVwqMpfl/okz/bpsvyiZkZcPserhESVbJsolCWmXKcw9C3s/uIdsXa1cKgIAhTRneXZfdx3ZPGW5goFfcJau7jxnc+5u4nUCN13LvTIHxofJBgDH67j3o7eHy1w6lbKSbJbLVLS+E9/n7JS2fmAsrJfD5JSelajS++NF9LKk5bA7iGEEYA5iGAGYgxhGAOYghhHAqgbp4gkiDZXBUnMbCy8MDrKk/aHXeQn2M6e4/wIAuns4oGvfxCUbvs+9CFOTvM+IEvT37lAC4M1ccpGZ00skclkO0ouK6EPmNJeQpE9zUJ1IcDBfp5SkvHsbl+x0x3jcALBhgvtJwq0snuBH+Dq6IgfaogTkxTwnX0SLpxWxidI+ufejMMf7jHqL329rFBpGzZiDGEYA5iCGEYA5iGEEsFJlxa8C
+ByA+eaCh5xzTy+3r1Q6i32vVfZgFBXp/VCIh/VWP/dpDA7qQXpjK4sfFIutZEsmeW09LUi/Sglsuzo5SB8YOEG21rAusx+5kRMJ4QRL+Z87eJRsR2d4GYGfHePtEj4Hqy1xnmX+8HW3qmO8I8oKjudGT5Mt1MwBeaGeezrySvDsfE5MOJ8/fy3wBoBiUZmJd8qM/eKlMqpc33KlyooA8E3n3J7yv2WdwzDWIytVVjSMtwW1xCBfEpFDIvKwiPDzS5mFyoqJKlf1MYy1wkod5NsAdgLYA2AYwNeX2nChsmJzS8tSmxnGmmRFM+nOudH5n0XkuwCequZ9c7kM3jp9uHIAilR9VzuXu4vSZB+v02dXP/SBj5Dt+l07yFacYwXHrjZFOr97G9k623j2ecdWLlff1rlZHaMm7JcY4uUPJmZYtLIfHJg23cRl7IUMVw9MT7LQxRNnWNwBAG7s4tL2q7Rp7hFOLmSaeYbbFbhFoFDgIN3Pc9BfXGLmO53lpEq8QVlbsW7xuC/hTHpZbnSeTwLgOhDDuAKoJs37AwB7AXSIyACAvwGwV0T2oOSGpwF84RKO0TAuGytVVvyXSzAWw1hz2Ey6YQSwquXu0aiPzb2VAV1rB8/s5vMcuH3kD1ihcGKCg0MACMc5SMvleJ+33HIj2bIpDiSHlKUO9tzA793Zu51s0+d12f7hES4lnzw3QDbvat7nXe/fS7asx4HtzCxfnwJfGhx98zAbAZx98xTZukIc3G7wOIHifN7OE95OlJYDpwyysERMnVMUF8NFRZmxUHktnDLbrmF3EMMIwBzEMAIwBzGMAMxBDCOAVQ3Sk6kEXtj/8wpbQQnItvVyufqeO3aR7UyfLhznCQe7k7O8HqFf5Jn4ZIKDxokZDrRffp1npI/38ez64KAepMeV8u3rY7wMgdfAM/EjSln8S/t/TbaCEodGYlxmn5jVVx/ORfj6JOKcDAiHeLs0+PyKSv94aHEZOoCwYssraxkCgCf8HR8K83iyc5XJF19JIqj7r2orw3ibYg5iGAGYgxhGAOYghhHAqgbpsXgYO6+uDETzSrlz1yZtVphLwZMpvdExHOaS7HyR19tLJDmAzitTtm09nDSIxDhID8W5V3z79fp3kF9ke1OYg/xfv8jrKB49yWJyTU3cayOeorqe40qBiWn9OvqO3+8UtfqkokCfyXG/vwjPcEejvJ6gZsso6v4AEI7y34rn8bUtUILAgnTDqBlzEMMIwBzEMAIwBzGMAKrpKNwK4HsANqIU2XzHOfePItIG4IcAelHqKrzPOcfR2gIa6uK4dU9l3/asUpJ97NjrZJuc5l1fv2u3epymxg3amZBlbJwDtXyOt0tO8zJfMymefW5v26TYdMGX2Sx/N8VDHGiH6zlwL+b5mkWFVfLrG1mJ3VMSAdPj59QxtnT3kq01yn8yiUkWzPOFky+xGAffnhK4Fwpcwq61QABAg7LcWlEpIWhorFS69zxddJDGV8U2BQBfcc7tAnA7gC+KyC4ADwJ4zjl3DYDnyq8N44qiGuG4Yefcq+WfkwDeALAFwL0AHi1v9iiAT1yqQRrG5eKCYhAR6QVwC4B9ADY65+ZXchlB6RFMe8/vhOOmJ3mewDDWMlU7iIg0AngcwJedcxUzbM45hyVmXhYKx7W08TOxYaxlqnIQEYmg5Bzfd879pGwendfHKv/PCmeGsc6pJoslKMn8vOGc+8aCXz0J4H4AXyv//8Ry+yr6BSRmKwUQPHBZyEyCsxDHj3PW6FT//6jH6dnGyow37dlJtm3KdnUeZ8CcIgJQVPpYohHutRCuhAAA1Gf4httdz2O8ZQ9naTqaudzjpRdeIltiirWQtf6b8UH9u801cH9K8VoeI5TrowlnxMJ8MTIpLknxi9z7EY3r3+UhRXEzl1GUKRZXGlVXaVJVLdadAD4D4LCIHCzbHkLJMX4kIg8AOAPgvuoOaRjrh2qE416ENolQ4oMXdziGsbawmXTDCMAcxDAC
WNV+EE+A+milTzqfg6w7b38X2XbuvIFs/WdOq8cZG2fRhukJRSY/wgmC0QwnA1paOHBvauKSDRdRylRmuG8EANoaeN3Dzi7uO0lu5cB//29+Q7aJaVZ/9JVrqyHcKgMAaGvjX7Rt4XKYlPI1G1HEFKLachXC0XImw6U0ztOj6oKizKiddnrRPqu9NnYHMYwAzEEMIwBzEMMIwBzEMAJY1SAd4uCFKoMqL6LI6SsL03ds2kK2G3br6/9lsxzk+Yqq3/D5YbKNJTjYHZsZJdumbg6om5s5qPWX6DuYzfN300T2ZbINTrKwxJFjPGs+l+Vxx+NLRN+LaGjWA+CtbUrvR/Is2bwWPk5LhKsUfHBPhyqw4Pizmk3q1zHkKYG/sgAkTfYvNbO3CLuDGEYA5iCGEYA5iGEEYA5iGAGsapCezc3hxFDlunfNLTwjHctxYLohzs1WrcpsNgDEldJoDywY0NXK5dyRMM9czyR5dj3kOMqbmeby8tFxXnYBABKjrBR5qoPFKnqabyHbH9/3PrId3s/v1dZlbGllEYk5pUwfANw0VwEcOXaIbL2dLBjR3sAl+QVFCXNCKW3fEOHZeqeIOwDAbIIFNeL1/LdSv6FyjJ6nVzgsxu4ghhGAOYhhBGAOYhgBmIMYRgC1KCt+FcDnAMxHsA85554O2lfRL2J6tjIAzxZY1j6mLC2Qb2omW3J2KXU8LmWur+PArbG+m2zxKAecnc1c7p5X1A215RQGTg2pIwwrSxMcGmWFw3PKZPi1US79b1Ouz+YurjTwlPLwbL0eAE9EuFd9CzgxUhfmY9c1KIqQaT6ZfJFVFHNZXqIhn9PXKEwrypyxGB+7tbVS9TIUrk5jpJos1ryy4qsi0gTgFRF5tvy7bzrn/qGqIxnGOqSanvRhAMPln5MiMq+saBhXPLUoKwLAl0TkkIg8LCKqSvNCZcVUgm+nhrGWqUVZ8dsAdgLYg9Id5uva+xYqKzYoVbqGsZapaiZdU1Z0zo0u+P13ATy13H6ikTh6Nl5dYSsoUvWeUq6cyfCs8Ni0rvWrzXxv3c5LE6QVOf5skvfZ2KjMFLcrs/ARFnnbsV1f/6++kQPW/j4u3Y6FlSUMuvmatWzkRMLsLM8yh4ocAO+88WqyAYB/nMvO8wUedzymLEHg8RjbG3m7cITPeeo8Vx+Iz/oBAJDO8FNJOMbbeqHKP3VtvUSNZe8gSykrzsuOlvkkgCNVHdEw1hG1KCt+WkT2oJT6PQ3gC5dkhIZxGalFWTFwzsMwrgRsJt0wAljVcnfnisgVKoPgWIxLrRvquNy5WOCZ1HSClcEBoKGeA79ingPyyTSvexhX1uDTFNp9jwPYdI5n9rs2aeslAvX1HLBu2qSUiBf5OHM+zx63t3EPeCbB28UjnHAI1fN2ABAf54C8boTPx/M58C+Ckx1eiD/rugb+rNMpTshE4rrQW9FxQsYXDtwzhcoqB1/pe9ewO4hhBGAOYhgBmIMYRgDmIIYRwKoG6UW/iFS6cma54LNoWXKWhdpCwkGtCAe1ANDcxPZ0mvcZUZYEkzAH+KksB9/JIS5t12auoZwfADifM+chRR3e95VgV8m6F9PcIhAOcWCbSnNAnczpffPSzLP40sABfeo8B9V5JQgugI89l+HrmHccZA8MD6pjHBnjSoXOzZwMcOnKJE9RKfvXsDuIYQRgDmIYAZiDGEYA5iCGEYA5iGEEsLqlJr6HfKayVCE1y83z2kLyuRxnaaJKuQcATL3FJSgzKc6C7H7HtWRLjHBGxxO+TOoad0pm6q0+PfsSi3JWrqWNsy/Nrfwd1tzCZTPIcbYrrpSzJGZZJCOd5iwUALiMIvAQ4cxfHlx+4ucVgYYQfy75MGex0nnOTPWfZUELAEgm+G+gpYf7QQpe5Tk66NnFxdgdxDACMAcxjADMQQwjgGpabuMi8rKIvC4iR0Xkb8v2q0Rkn4icEpEfiojyYGwY65tqgvQ5AB9wzs2WxRteFJGfA/hLlITj
HhORfwbwAEpKJ0uSz/kYGqgsx/CVwDYa4RKHwWEOnnM5XRAhrCxh0NLKgeTgsFLS4vF4PPD+6pW+Ck2VMRzTpY6OnzpOts1ZHmP4PJdnRCKcIGisZzXBhgZWPMxkOEgPRZfqteAAujHew9t5SsNMhktSpgp8vaWLy3MmZ/mzTs7qY8w6/o7vfScrT+6+ZXvF64OHn1H3t5hl7yCuxHwxUqT8zwH4AIAfl+2PAvhEVUc0jHVEVTGIiITKgg1jAJ4F0Adg2jk3nwccwBJqiwuF49KzejrRMNYqVTmIc67onNsDoAfAbQCur/YAC4Xj6hstTDHWFxeUxXLOTQN4HsB7ALSI/G4GrQeAPiNmGOuYapY/6ASQd85Ni0gdgLsB/D1KjvJHAB4DcD+AJ5bb19xcHn19w5X7V5YqaGpk28wU+3IyqT+y7drNsv+921kJcWDoNB+7iSWGXZ5nXesbOKCOKYF77zZdwa+tjWeas1meaZ5W1glMTClqlG3Kun557m3xPD5uInVeHWOuyLPz0wkWSdiQ4hn7mBI8Zz3eXyzK2yWSSh9LSv8ub97CTyXxTkW0o7EyOeGUXhmNarJY3QAeFZEQSnecHznnnhKRYwAeE5G/A/AaSuqLhnFFUY1w3CGUFN0X2/tRikcM44rFZtINIwBzEMMIQJyrruz3ohxMZBzAGQAdAPTIcP1h57I2We5ctjvnOpfbyao6yO8OKnLAOXfrqh/4EmDnsja5WOdij1iGEYA5iGEEcLkc5DuX6biXAjuXtclFOZfLEoMYxnrBHrEMIwBzEMMIYNUdRETuEZE3y626D6728WtBRB4WkTERObLA1iYiz4rIyfL/XO24BhGRrSLyvIgcK7dS/0XZvu7O51K2ha+qg5QLHr8F4KMAdqG0Uu6u1RxDjTwC4J5FtgcBPOecuwbAc+XX64ECgK8453YBuB3AF8ufxXo8n/m28JsB7AFwj4jcjlLV+Tedc1cDmEKpLfyCWO07yG0ATjnn+p1zOZRK5e9d5TGsGOfcCwAWN8Lfi1LLMbCOWo+dc8POuVfLPycBvIFSV+i6O59L2Ra+2g6yBcBCibwlW3XXERudc/NNLiMANl7OwawEEelFqWJ7H9bp+dTSFh6EBekXEVfKma+rvLmINAJ4HMCXnauUMVlP51NLW3gQq+0ggwC2Lnh9JbTqjopINwCU/2ex4TVKWcbpcQDfd879pGxet+cDXPy28NV2kP0ArilnF6IAPgXgyVUew8XmSZRajoEqW4/XAiIiKHWBvuGc+8aCX6278xGRThFpKf883xb+Bv6/LRxY6bk451b1H4CPATiB0jPiX6328Wsc+w8ADAPIo/RM+wCAdpSyPScB/BJA2+UeZ5Xn8l6UHp8OAThY/vex9Xg+AG5Cqe37EIAjAP66bN8B4GUApwD8O4DYhe7bSk0MIwAL0g0jAHMQwwjAHMQwAjAHMYwAzEEMIwBzEMMIwBzEMAL4P/reBAlsXKWPAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 216x216 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import paddle.fluid as fluid\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "image_path = './work/out/img.png' # 图片路径\n",
    "model_path = './work/out/hs-resnet-best' # 模型路径\n",
    "\n",
    "# Load and preprocess an image\n",
    "def load_image(image_path):\n",
    "    \"\"\"\n",
    "    Read an image file and convert it into the model's input format.\n",
    "\n",
    "    Args:\n",
    "        image_path - path to the input image file\n",
    "    Returns:\n",
    "        image - float32 array of shape (1, 3, 32, 32) (NCHW), normalized\n",
    "                with the CIFAR per-channel mean/std defined below\n",
    "    \"\"\"\n",
    "    # Read the image. The context manager closes the file handle, and\n",
    "    # convert('RGB') guarantees exactly 3 channels: PNGs are often RGBA,\n",
    "    # palette or grayscale, which would break the per-channel normalization.\n",
    "    with Image.open(image_path) as src:\n",
    "        img = src.convert('RGB')\n",
    "\n",
    "    # Resize to the 32x32 network input. LANCZOS is the filter that the\n",
    "    # ANTIALIAS name aliased; ANTIALIAS itself was removed in Pillow 10.\n",
    "    img = img.resize((32, 32), Image.LANCZOS)\n",
    "    image = np.array(img, dtype=np.float32) # HWC, values in [0, 255]\n",
    "\n",
    "    # Normalize: scale to [0, 1], subtract the channel mean, divide by the channel std\n",
    "    mean = np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, -1)) # CIFAR per-channel mean\n",
    "    stdv = np.array([0.2471, 0.2435, 0.2616]).reshape((1, 1, -1)) # CIFAR per-channel std\n",
    "\n",
    "    image = (image/255.0 - mean) / stdv\n",
    "    # HWC -> CHW; cast back to float32 because mean/stdv are float64\n",
    "    image = image.transpose((2, 0, 1)).astype(np.float32)\n",
    "\n",
    "    # Add a leading batch dimension: (3, 32, 32) -> (1, 3, 32, 32)\n",
    "    image = np.expand_dims(image, axis=0)\n",
    "\n",
    "    return image\n",
    "\n",
    "# Predict the image\n",
    "with fluid.dygraph.guard():\n",
    "    # Read and preprocess the input image\n",
    "    image = load_image(image_path)\n",
    "    image = fluid.dygraph.to_variable(image)\n",
    "\n",
    "    # Build the network and load the trained weights\n",
    "    model = SSRNet()                               # SSRNet must be defined in an earlier cell\n",
    "    model_dict, _ = fluid.load_dygraph(model_path) # load saved parameters; second return value is unused\n",
    "    model.set_dict(model_dict)                     # copy the weights into the network\n",
    "    model.eval()                                   # switch to evaluation (inference) mode\n",
    "\n",
    "    # Forward pass; perf_counter is a monotonic, high-resolution clock meant for timing\n",
    "    infer_time = time.perf_counter()              # inference start time\n",
    "    infer = model(image)\n",
    "    infer_time = time.perf_counter() - infer_time # elapsed inference time in seconds\n",
    "\n",
    "    # CIFAR-100 label names. They are sorted alphabetically below so that list\n",
    "    # index i lines up with class id i (assumes the training labels were\n",
    "    # numbered in alphabetical order; TODO confirm against the training reader).\n",
    "    vlist = ['beaver', 'dolphin', 'otter', 'seal', 'whale',\n",
    "             'aquarium fish', 'flatfish', 'ray', 'shark', 'trout',\n",
    "             'orchids', 'poppies', 'roses', 'sunflowers', 'tulips',\n",
    "             'bottles', 'bowls', 'cans', 'cups', 'plates',\n",
    "             'apples', 'mushrooms', 'oranges', 'pears', 'sweet peppers',\n",
    "             'clock', 'keyboard', 'lamp', 'telephone', 'television',\n",
    "             'bed', 'chair', 'couch', 'table', 'wardrobe',\n",
    "             'bee', 'beetle', 'butterfly', 'caterpillar', 'cockroach',\n",
    "             'bear', 'leopard', 'lion', 'tiger', 'wolf',\n",
    "             'bridge', 'castle', 'house', 'road', 'skyscraper',\n",
    "             'cloud', 'forest', 'mountain', 'plain', 'sea',\n",
    "             'camel', 'cattle', 'chimpanzee', 'elephant', 'kangaroo',\n",
    "             'fox', 'porcupine', 'possum', 'raccoon', 'skunk',\n",
    "             'crab', 'lobster', 'snail', 'spider', 'worm',\n",
    "             'baby', 'boy', 'girl', 'man', 'woman',\n",
    "             'crocodile', 'dinosaur', 'lizard', 'snake', 'turtle',\n",
    "             'hamster', 'mouse', 'rabbit', 'shrew', 'squirrel',\n",
    "             'maple', 'oak', 'palm', 'pine', 'willow',\n",
    "             'bicycle', 'bus', 'motorcycle', 'pickup truck', 'train',\n",
    "             'lawn-mower', 'rocket', 'streetcar', 'tank', 'tractor'] # label names\n",
    "    vlist.sort() # ascending alphabetical order\n",
    "    label = vlist[np.argmax(infer.numpy())] # highest-scoring class\n",
    "    print('infer time: {:f}s, infer value: {}'.format(infer_time, label))\n",
    "\n",
    "    # Show the input image titled with the prediction. The context manager closes\n",
    "    # the file; convert('RGB') loads the pixels into memory before it is closed.\n",
    "    with Image.open(image_path) as raw_image:\n",
    "        display_image = raw_image.convert('RGB')\n",
    "    plt.figure(figsize=(3, 3))  # display size\n",
    "    plt.imshow(display_image)\n",
    "    plt.title(label)\n",
    "    plt.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
